[TensorExpr] Delete DimArg class. (#72390)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72390

This class didn't add much value and only led to more boilerplate code.
This change removes the class and updates all of its use sites to pass
`ExprHandle`s directly.

A side effect of this change is that loop variables get different names,
which required massive mechanical updates to our tests.
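
For reference, a minimal sketch of what the change looks like at a call site.
The buffer, names, and sizes below are illustrative and the header paths are
assumed from the PyTorch source tree; the `Compute`/`Reduce` signatures follow
the updated call sites in this diff.

#include <torch/csrc/jit/tensorexpr/loopnest.h>
#include <torch/csrc/jit/tensorexpr/tensor.h>

using namespace torch::jit::tensorexpr;

void example() {
  BufHandle a("a", {64, 32}, kFloat);

  // Before: every dimension was a DimArg, pairing an extent with an explicit
  // loop-variable name.
  //   Tensor t = Compute(
  //       "t", {{64, "m"}, {32, "n"}},
  //       [&](const VarHandle& m, const VarHandle& n) { return a.load(m, n) * 2.f; });

  // After: dimensions are plain ExprHandles; loop-variable names are generated
  // automatically (hence the renames throughout the tests).
  Tensor t = Compute(
      "t", {64, 32}, [&](const VarHandle& m, const VarHandle& n) {
        return a.load(m, n) * 2.f;
      });

  // Reductions follow the same pattern: output dims and reduction dims are
  // both plain ExprHandle lists now.
  Tensor s = Reduce(
      "s", {64}, Sum(),
      [&](const ExprHandle& m, const ExprHandle& n) { return a.load(m, n); },
      {32});

  LoopNest nest({t, s});
  nest.prepareForCodegen();
}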

Test Plan: Imported from OSS

Reviewed By: navahgar

Differential Revision: D34030296

Pulled By: ZolotukhinM

fbshipit-source-id: 2ba4e313506a43ab129a10d99e72b638b7d40108
(cherry picked from commit c2ec46a058)
Mikhail Zolotukhin 2022-02-10 17:17:09 -08:00 committed by PyTorch MergeBot
parent 9123e9b3b5
commit 1855b14922
39 changed files with 948 additions and 1204 deletions


@ -82,10 +82,8 @@ BENCHMARK_DEFINE_F(BatchNorm, NNC)(benchmark::State& state) {
VarHandle eps("eps", kFloat);
using axis = const VarHandle&;
Tensor output = Compute(
"output",
{{N_, "N"}, {C_, "C"}, {H_, "H"}, {W_, "W"}},
[&](axis n, axis c, axis h, axis w) {
Tensor output =
Compute("output", {N_, C_, H_, W_}, [&](axis n, axis c, axis h, axis w) {
// Compute affine terms.
auto inv_var = FloatImm::make(1.0f) / sqrt(var.load(c) + eps);
auto weight_v = weight.load(c);
@ -143,10 +141,8 @@ BENCHMARK_DEFINE_F(BatchNorm, NNCRelu)(benchmark::State& state) {
VarHandle eps("eps", kFloat);
using axis = const VarHandle&;
Tensor output = Compute(
"output",
{{N_, "N"}, {C_, "C"}, {H_, "H"}, {W_, "W"}},
[&](axis n, axis c, axis h, axis w) {
Tensor output =
Compute("output", {N_, C_, H_, W_}, [&](axis n, axis c, axis h, axis w) {
// Compute affine terms.
auto inv_var = FloatImm::make(1.0f) / sqrt(var.load(c) + eps);
auto weight_v = weight.load(c);


@ -12,26 +12,21 @@ static void BM_CompileSwish(benchmark::State& state) {
constexpr int N = 512;
te::VarHandle n("n", te::kInt);
te::BufHandle A("A", {N}, te::kFloat);
te::Tensor relu =
te::Compute("relu", {{n, "n"}}, [&](const te::VarHandle& i) {
return te::Max::make(A.load(i), 0.f, false);
});
te::Tensor min6 =
te::Compute("min6", {{n, "n"}}, [&](const te::VarHandle& i) {
return te::Min::make(relu.load(i), 6.f, false);
});
te::Tensor plus3 =
te::Compute("plus3", {{n, "n"}}, [&](const te::VarHandle& i) {
return min6.load(i) + 3.f;
});
te::Tensor times =
te::Compute("times", {{n, "n"}}, [&](const te::VarHandle& i) {
return A.load(i) * plus3.load(i);
});
te::Tensor sixth =
te::Compute("sixth", {{n, "n"}}, [&](const te::VarHandle& i) {
return times.load(i) * 1.f / 6.f;
});
te::Tensor relu = te::Compute("relu", {n}, [&](const te::VarHandle& i) {
return te::Max::make(A.load(i), 0.f, false);
});
te::Tensor min6 = te::Compute("min6", {n}, [&](const te::VarHandle& i) {
return te::Min::make(relu.load(i), 6.f, false);
});
te::Tensor plus3 = te::Compute("plus3", {n}, [&](const te::VarHandle& i) {
return min6.load(i) + 3.f;
});
te::Tensor times = te::Compute("times", {n}, [&](const te::VarHandle& i) {
return A.load(i) * plus3.load(i);
});
te::Tensor sixth = te::Compute("sixth", {n}, [&](const te::VarHandle& i) {
return times.load(i) * 1.f / 6.f;
});
te::LoopNest nest({sixth}, {relu, min6, plus3, times, sixth});
for (auto tensor : {relu, min6, plus3, times}) {
nest.computeInline(tensor.buf());
@ -46,26 +41,20 @@ static void BM_CompileSwishLLVMOnly(benchmark::State& state) {
constexpr int N = 512;
te::VarHandle n("n", te::kInt);
te::BufHandle A("A", {N}, te::kFloat);
te::Tensor relu =
te::Compute("relu", {{n, "n"}}, [&](const te::VarHandle& i) {
return te::Max::make(A.load(i), 0.f, false);
});
te::Tensor min6 =
te::Compute("min6", {{n, "n"}}, [&](const te::VarHandle& i) {
return te::Min::make(relu.load(i), 6.f, false);
});
te::Tensor plus3 =
te::Compute("plus3", {{n, "n"}}, [&](const te::VarHandle& i) {
return min6.load(i) + 3.f;
});
te::Tensor times =
te::Compute("times", {{n, "n"}}, [&](const te::VarHandle& i) {
return A.load(i) * plus3.load(i);
});
te::Tensor sixth =
te::Compute("sixth", {{n, "n"}}, [&](const te::VarHandle& i) {
return times.load(i) * 1.f / 6.f;
});
te::Tensor relu = te::Compute("relu", {n}, [&](const te::VarHandle& i) {
return te::Max::make(A.load(i), 0.f, false);
});
te::Tensor min6 = te::Compute("min6", {n}, [&](const te::VarHandle& i) {
return te::Min::make(relu.load(i), 6.f, false);
});
te::Tensor plus3 = te::Compute(
"plus3", {n}, [&](const te::VarHandle& i) { return min6.load(i) + 3.f; });
te::Tensor times = te::Compute("times", {n}, [&](const te::VarHandle& i) {
return A.load(i) * plus3.load(i);
});
te::Tensor sixth = te::Compute("sixth", {n}, [&](const te::VarHandle& i) {
return times.load(i) * 1.f / 6.f;
});
te::LoopNest nest({sixth}, {relu, min6, plus3, times, sixth});
for (auto tensor : {relu, min6, plus3, times}) {
nest.computeInline(tensor.buf());


@ -61,7 +61,7 @@ class ConcatBench : public benchmark::Fixture {
Tensor output = Compute(
"aten_cat",
{{output_size_[0], "M"}, {output_size_[1], "N"}},
{output_size_[0], output_size_[1]},
[&](const VarHandle& m, const VarHandle& n) {
int d = 0;
std::vector<int> cumulative_concat_dim_sizes(num_inputs);


@ -44,12 +44,12 @@ BENCHMARK_DEFINE_F(Gemm, TensorExprNoopt)(benchmark::State& state) {
te::BufHandle BP("B", {K, N}, te::kFloat);
te::Tensor CT = te::Reduce(
"gemm",
{{M, "M"}, {N, "N"}},
{M, N},
te::Sum(),
[&](const te::ExprHandle& m,
const te::ExprHandle& n,
const te::ExprHandle& k) { return AP.load(m, k) * BP.load(k, n); },
{{K, "K"}});
{K});
te::LoopNest loop({CT});
loop.prepareForCodegen();
te::StmtPtr s = loop.root_stmt();
@ -66,12 +66,12 @@ BENCHMARK_DEFINE_F(Gemm, TensorExprTile32x32)(benchmark::State& state) {
te::BufHandle BP("B", {K, N}, te::kFloat);
te::Tensor CT = te::Reduce(
"gemm",
{{M, "M"}, {N, "N"}},
{M, N},
te::Sum(),
[&](const te::ExprHandle& m,
const te::ExprHandle& n,
const te::ExprHandle& k) { return AP.load(m, k) * BP.load(k, n); },
{{K, "K"}});
{K});
te::LoopNest loop({CT});
{
@ -124,12 +124,12 @@ BENCHMARK_DEFINE_F(Gemm, TensorExprTile4x16)(benchmark::State& state) {
te::BufHandle BP("B", {K, N}, te::kFloat);
te::Tensor CT = te::Reduce(
"gemm",
{{M, "M"}, {N, "N"}},
{M, N},
te::Sum(),
[&](const te::ExprHandle& m,
const te::ExprHandle& n,
const te::ExprHandle& k) { return AP.load(m, k) * BP.load(k, n); },
{{K, "K"}});
{K});
te::LoopNest loop({CT});
{
@ -182,12 +182,12 @@ BENCHMARK_DEFINE_F(Gemm, TensorExprTile4x16VecUnroll)(benchmark::State& state) {
te::BufHandle BP("B", {K, N}, te::kFloat);
te::Tensor CT = te::Reduce(
"gemm",
{{M, "M"}, {N, "N"}},
{M, N},
te::Sum(),
[&](const te::ExprHandle& m,
const te::ExprHandle& n,
const te::ExprHandle& k) { return AP.load(m, k) * BP.load(k, n); },
{{K, "K"}});
{K});
te::LoopNest loop({CT});
{
@ -248,12 +248,12 @@ BENCHMARK_DEFINE_F(Gemm, TensorExprTile4x16Cache)(benchmark::State& state) {
te::BufHandle BP("B", {K, N}, te::kFloat);
te::Tensor CT = te::Reduce(
"gemm",
{{M, "M"}, {N, "N"}},
{M, N},
te::Sum(),
[&](const te::ExprHandle& m,
const te::ExprHandle& n,
const te::ExprHandle& k) { return AP.load(m, k) * BP.load(k, n); },
{{K, "K"}});
{K});
te::LoopNest loop({CT});
{


@ -38,7 +38,7 @@ class ParallelAdd : public benchmark::Fixture {
BENCHMARK_DEFINE_F(ParallelAdd, Simple)(benchmark::State& state) {
BufHandle a_buf("a", {M}, kFloat);
BufHandle b_buf("b", {M}, kFloat);
Tensor c_tensor = Compute("c", {{M, "m"}}, [&](const VarHandle& m) {
Tensor c_tensor = Compute("c", {M}, [&](const VarHandle& m) {
return a_buf.load(m) + b_buf.load(m);
});
LoopNest loop_nest({c_tensor});


@ -235,12 +235,12 @@ BENCHMARK_DEFINE_F(Reduce1D, TeNaive)(benchmark::State& state) {
te::BufHandle AP("A", {M}, te::kFloat);
te::Tensor BT = te::Reduce(
"reduce_full",
{{1, "N"}},
{1},
te::Sum(),
[&](const te::ExprHandle& n, const te::ExprHandle& m) {
return AP.load(m);
},
{{M, "M"}});
{M});
te::LoopNest loop({BT});
loop.prepareForCodegen();
@ -266,12 +266,12 @@ BENCHMARK_DEFINE_F(Reduce1D, TeSplitTail)(benchmark::State& state) {
te::BufHandle AP("A", {M}, te::kFloat);
te::Tensor BT = te::Reduce(
"reduce_full",
{{1, "N"}},
{1},
te::Sum(),
[&](const te::ExprHandle& n, const te::ExprHandle& m) {
return AP.load(m);
},
{{M, "M"}});
{M});
te::LoopNest loop({BT});
const int kChunkSize = 8;
@ -305,12 +305,12 @@ BENCHMARK_DEFINE_F(Reduce1D, TeSplitMask)(benchmark::State& state) {
te::BufHandle AP("A", {M}, te::kFloat);
te::Tensor BT = te::Reduce(
"reduce_full",
{{1, "N"}},
{1},
te::Sum(),
[&](const te::ExprHandle& n, const te::ExprHandle& m) {
return AP.load(m);
},
{{M, "M"}});
{M});
te::LoopNest loop({BT});
const int kChunkSize = 8;
@ -349,7 +349,7 @@ BENCHMARK_DEFINE_F(Reduce1D, TeRfactorV1)(benchmark::State& state) {
{},
te::Sum(),
[&](const te::ExprHandle& m) { return AP.load(m); },
{{M, "M"}});
{M});
te::LoopNest loop({BT});
te::BufPtr rfac_buf;


@ -46,13 +46,13 @@ class SignedLog1pBench : public benchmark::Fixture {
"input", {input_size_int_[0], input_size_int_[1]}, kFloat);
Tensor abs_result = Compute(
"aten_abs",
{{input_size_int_[0], "M"}, {input_size_int_[1], "N"}},
{input_size_int_[0], input_size_int_[1]},
[&](const VarHandle& m, const VarHandle& n) {
return abs(input_ph.load(m, n));
});
Tensor log1p_result = Compute(
"aten_log1p",
{{input_size_int_[0], "M"}, {input_size_int_[1], "N"}},
{input_size_int_[0], input_size_int_[1]},
[&](const VarHandle& m, const VarHandle& n) {
return log1p(abs_result.load(m, n));
});
@ -60,7 +60,7 @@ class SignedLog1pBench : public benchmark::Fixture {
computeSign({input_ph}, {input_size_int_[0], input_size_int_[1]});
Tensor output = Compute(
"aten_mul",
{{input_size_int_[0], "M"}, {input_size_int_[1], "N"}},
{input_size_int_[0], input_size_int_[1]},
[&](const VarHandle& m, const VarHandle& n) {
return sign_result.load(m, n) * log1p_result.load(m, n);
});
@ -94,13 +94,13 @@ class SignedLog1pBench : public benchmark::Fixture {
"input", {input_size_int_[0], input_size_int_[1]}, kFloat);
Tensor abs_result = Compute(
"aten_abs",
{{input_size_int_[0], "M"}, {input_size_int_[1], "N"}},
{input_size_int_[0], input_size_int_[1]},
[&](const VarHandle& m, const VarHandle& n) {
return abs(input_ph.load(m, n));
});
Tensor log_vml_result = Compute(
"aten_log1p",
{{input_size_int_[0], "M"}, {input_size_int_[1], "N"}},
{input_size_int_[0], input_size_int_[1]},
[&](const VarHandle& m, const VarHandle& n) {
return log_vml(abs_result.load(m, n) + ExprHandle(1));
});
@ -108,7 +108,7 @@ class SignedLog1pBench : public benchmark::Fixture {
computeSign({input_ph}, {input_size_int_[0], input_size_int_[1]});
Tensor output = Compute(
"aten_mul",
{{input_size_int_[0], "M"}, {input_size_int_[1], "N"}},
{input_size_int_[0], input_size_int_[1]},
[&](const VarHandle& m, const VarHandle& n) {
return sign_result.load(m, n) * log_vml_result.load(m, n);
});


@ -49,8 +49,7 @@ TEST(BoundsInference, _1) {
// {{b, kStore, 0, 99}, {a, kLoad, 0, 99}}
ExprHandle n(100);
BufHandle a("a", {n}, kFloat);
Tensor b =
Compute("b", {{n, "i"}}, [&](const VarHandle& i) { return a.load(i); });
Tensor b = Compute("b", {n}, [&](const VarHandle& i) { return a.load(i); });
LoopNest l({b});
auto bounds_info = inferBounds(l.root_stmt());
@ -73,8 +72,7 @@ TEST(BoundsInference, _2) {
// {{b, kStore, 0, n-1}, {a, kLoad, 0, n-1}}
VarHandle n("n", kInt);
BufHandle a("a", {n}, kFloat);
Tensor b =
Compute("b", {{n, "i"}}, [&](const VarHandle& i) { return a.load(i); });
Tensor b = Compute("b", {n}, [&](const VarHandle& i) { return a.load(i); });
LoopNest l({b});
auto bounds_info = inferBounds(l.root_stmt());
@ -97,9 +95,8 @@ TEST(BoundsInference, _3) {
// {{b, kStore, 0, 99}, {a, kLoad, 0, 109}}
ExprHandle n(100);
BufHandle a("a", {n + 10}, kFloat);
Tensor b = Compute("b", {{n, "i"}}, [&](const VarHandle& i) {
return a.load(i) * a.load(i + 10);
});
Tensor b = Compute(
"b", {n}, [&](const VarHandle& i) { return a.load(i) * a.load(i + 10); });
LoopNest l({b});
auto bounds_info = inferBounds(l.root_stmt());
@ -126,14 +123,12 @@ TEST(BoundsInference, _4) {
ExprHandle W(320);
ExprHandle H(200);
BufHandle a("a", {H, W}, kFloat);
Tensor b = Compute(
"b", {{H, "y"}, {W, "x"}}, [&](const VarHandle& y, const VarHandle& x) {
return x * y;
});
Tensor c = Compute(
"c", {{H, "y"}, {W, "x"}}, [&](const VarHandle& y, const VarHandle& x) {
return a.load(y, x) * b.load(y, x);
});
Tensor b = Compute("b", {H, W}, [&](const VarHandle& y, const VarHandle& x) {
return x * y;
});
Tensor c = Compute("c", {H, W}, [&](const VarHandle& y, const VarHandle& x) {
return a.load(y, x) * b.load(y, x);
});
LoopNest l({c});
std::vector<ForPtr> loops = l.getLoopStmtsFor(c);
StmtPtr body = l.getLoopBodyFor(c);
@ -204,8 +199,7 @@ TEST(BoundsInference, _5) {
// b[i_tail + (100/16)*16] = a[i_tail + (100/16)*16];
ExprHandle n(100);
BufHandle a("a", {n}, kFloat);
Tensor b =
Compute("b", {{n, "i"}}, [&](const VarHandle& i) { return a.load(i); });
Tensor b = Compute("b", {n}, [&](const VarHandle& i) { return a.load(i); });
LoopNest l({b});
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
@ -258,12 +252,11 @@ TEST(BoundsInference, _6) {
ExprHandle CW(32);
ExprHandle CH(20);
BufHandle a("a", {H, W}, kFloat);
Tensor b = Compute(
"b", {{H, "y"}, {W, "x"}}, [&](const VarHandle& y, const VarHandle& x) {
return x * y;
});
Tensor c = Compute(
"c", {{CH, "y"}, {CW, "x"}}, [&](const VarHandle& y, const VarHandle& x) {
Tensor b = Compute("b", {H, W}, [&](const VarHandle& y, const VarHandle& x) {
return x * y;
});
Tensor c =
Compute("c", {CH, CW}, [&](const VarHandle& y, const VarHandle& x) {
return a.load(y + 100, x + 100) * b.load(y * 2, x * 5);
});
LoopNest l({c});
@ -325,10 +318,9 @@ TEST(BoundsInference, _6) {
TEST(BoundsInference, Adjacent) {
ExprHandle H(6);
BufHandle a("a", {20}, kFloat);
Tensor b =
Compute("b", {{H, "x"}}, [&](const VarHandle& x) { return a.load(x); });
Tensor c = Compute(
"c", {{H, "x"}}, [&](const VarHandle& x) { return a.load(x + H); });
Tensor b = Compute("b", {H}, [&](const VarHandle& x) { return a.load(x); });
Tensor c =
Compute("c", {H}, [&](const VarHandle& x) { return a.load(x + H); });
LoopNest l({b, c});
std::vector<ForPtr> loops = NodeFinder<For>::find(l.root_stmt());
@ -383,12 +375,11 @@ TEST(BoundsInference, Adjacent) {
TEST(BoundsInference, MultipleTopLoopLoad) {
BufHandle a("a", {100}, kFloat);
Tensor b =
Compute("b", {{64, "x"}}, [&](const VarHandle& x) { return a.load(x); });
Tensor c = Compute(
"c", {{32, "x"}}, [&](const VarHandle& x) { return a.load(x + 10); });
Tensor d = Compute(
"d", {{96, "x"}}, [&](const VarHandle& x) { return a.load(x + 2); });
Tensor b = Compute("b", {64}, [&](const VarHandle& x) { return a.load(x); });
Tensor c =
Compute("c", {32}, [&](const VarHandle& x) { return a.load(x + 10); });
Tensor d =
Compute("d", {96}, [&](const VarHandle& x) { return a.load(x + 2); });
LoopNest l({b, c, d});
auto bounds_info = inferBounds(l.root_stmt());
@ -496,16 +487,15 @@ TEST(BoundsInference, MultipleTopLoopStore) {
}
TEST(BoundsInference, CacheReads) {
Tensor A = Compute(
"A", {{64, "i"}, {64, "j"}}, [](const VarHandle& i, const VarHandle& j) {
return i * j;
});
Tensor B = Compute(
"B", {{20, "i"}, {10, "j"}}, [&](const VarHandle& i, const VarHandle& j) {
Tensor A = Compute("A", {64, 64}, [](const VarHandle& i, const VarHandle& j) {
return i * j;
});
Tensor B =
Compute("B", {20, 10}, [&](const VarHandle& i, const VarHandle& j) {
return A.load(i + 30, j + 3);
});
Tensor C = Compute(
"C", {{20, "i"}, {10, "j"}}, [&](const VarHandle& i, const VarHandle& j) {
Tensor C =
Compute("C", {20, 10}, [&](const VarHandle& i, const VarHandle& j) {
return A.load(i + 10, j + 20) + A.load(i + 30, j + 40);
});
@ -562,7 +552,7 @@ TEST(BoundsInference, CacheReads) {
TEST(BoundsInference, Flattened) {
Tensor b = Compute(
"b",
{{3, "z"}, {4, "y"}, {5, "x"}},
{3, 4, 5},
[&](const VarHandle& z, const VarHandle& y, const VarHandle& x) {
return x * y + z;
});
@ -637,14 +627,12 @@ TEST(BoundsInference, GetPotentialHazards) {
}
TEST(BoundsInference, GetPotentialHazardsLoopNoHazard) {
Tensor A = Compute(
"A", {{64, "i"}, {64, "j"}}, [](const VarHandle& i, const VarHandle& j) {
return i * j;
});
Tensor B = Compute(
"B", {{64, "i"}, {64, "j"}}, [](const VarHandle& i, const VarHandle& j) {
return (i + 1) * (j + 1);
});
Tensor A = Compute("A", {64, 64}, [](const VarHandle& i, const VarHandle& j) {
return i * j;
});
Tensor B = Compute("B", {64, 64}, [](const VarHandle& i, const VarHandle& j) {
return (i + 1) * (j + 1);
});
LoopNest l({A, B});
@ -663,12 +651,11 @@ TEST(BoundsInference, GetPotentialHazardsLoopNoHazard) {
}
TEST(BoundsInference, GetPotentialHazardsLoopCall) {
Tensor A = Compute(
"A", {{64, "i"}, {64, "j"}}, [](const VarHandle& i, const VarHandle& j) {
return i * j;
});
Tensor B = Compute(
"B", {{64, "i"}, {64, "j"}}, [&](const VarHandle& i, const VarHandle& j) {
Tensor A = Compute("A", {64, 64}, [](const VarHandle& i, const VarHandle& j) {
return i * j;
});
Tensor B =
Compute("B", {64, 64}, [&](const VarHandle& i, const VarHandle& j) {
return A.load(i, j) + 5;
});
@ -688,10 +675,9 @@ TEST(BoundsInference, GetPotentialHazardsLoopCall) {
}
TEST(BoundsInference, GetPotentialHazardsLoopSplit) {
Tensor A = Compute(
"A", {{64, "i"}, {64, "j"}}, [](const VarHandle& i, const VarHandle& j) {
return i * j;
});
Tensor A = Compute("A", {64, 64}, [](const VarHandle& i, const VarHandle& j) {
return i * j;
});
LoopNest l({A});
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)


@ -191,7 +191,7 @@ TEST(Conv, Conv2D) {
te::Tensor conv = te::Reduce(
"conv",
{{N, "n"}, {K, "k"}, {OH, "oh"}, {OW, "ow"}},
{N, K, OH, OW},
te::Sum(),
// FIXME: We have to use a `std::vector` parameter here and then unpack
// it, because we don't have an overload allowing for an arbitrary number
@ -211,7 +211,7 @@ TEST(Conv, Conv2D) {
},
// FIXME: If you forget one of the reduction dims, you get a segfault.
// Could that be caught by a verifier?
{{C, "c"}, {R, "r"}, {S, "s"}});
{C, R, S});
// FIXME: It'd be nice to have a single header that pulls in things like
// LoopNest, IRSimplifier, etc.


@ -37,9 +37,9 @@ static void testCudaTestVectorAdd01_impl() {
Tensor c = Compute(
"c",
{
{num_iter, "n"},
{block_count, "b_id"},
{block_size, "t_id"},
num_iter,
block_count,
block_size,
},
[&](const VarHandle& n, const VarHandle& b_id, const VarHandle& t_id) {
return a_buf.load(n, b_id, t_id) + b_buf.load(n, b_id, t_id);
@ -101,9 +101,9 @@ TEST(Cuda, Sigmoid_CUDA) {
Tensor c = Compute(
"c",
{
{num_iter, "n"},
{block_count, "b_id"},
{block_size, "t_id"},
num_iter,
block_count,
block_size,
},
[&](const VarHandle& n, const VarHandle& b_id, const VarHandle& t_id) {
return sigmoid(sigmoid(a_buf.load(n, b_id, t_id)));
@ -163,12 +163,9 @@ TEST(Cuda, TestVectorAdd01_CUDA) {
static void testCudaTestVectorAdd02_impl(int64_t N, int64_t block_size) {
BufHandle a_buf("a", {N}, kFloat);
BufHandle b_buf("b", {N}, kFloat);
Tensor c = Compute(
"c",
{
{N, "N"},
},
[&](const VarHandle& n) { return a_buf.load(n) + b_buf.load(n); });
Tensor c = Compute("c", {N}, [&](const VarHandle& n) {
return a_buf.load(n) + b_buf.load(n);
});
LoopNest l({c});
ForPtr n_inner;
std::vector<ForPtr> loops = l.getLoopStmtsFor(c);
@ -222,7 +219,7 @@ TEST(Cuda, TestVectorAdd02_CUDA) {
TEST(Cuda, HalfCast_CUDA) {
auto half = ToDtype<at::Half>();
BufHandle a("a", {4}, half);
Tensor b = Compute("b", {{4, "n"}}, [&](const VarHandle& i) {
Tensor b = Compute("b", {4}, [&](const VarHandle& i) {
return Cast::make(kFloat, a.load(i));
});
@ -263,8 +260,8 @@ TEST(Cuda, DynamicShape2D_CUDA) {
VarHandle n("n", kInt);
BufHandle a("a", {m, n}, kFloat);
BufHandle b("b", {m, n}, kFloat);
Tensor c = Compute(
"c", {{m, "m"}, {n, "n"}}, [&](const VarHandle& i, const VarHandle& j) {
Tensor c =
Compute("c", {m, n}, [&](const VarHandle& i, const VarHandle& j) {
return a.load(i, j) + b.load(i, j);
});
LoopNest l({c});
@ -326,9 +323,9 @@ TEST(Cuda, TestRand01_CUDA) {
Tensor c = Compute(
"c",
{
{num_iter, "n"},
{block_count, "b_id"},
{block_size, "t_id"},
num_iter,
block_count,
block_size,
},
[&](const VarHandle& n, const VarHandle& b_id, const VarHandle& t_id) {
return Intrinsics::make(IntrinsicsOp::kRand, kFloat);
@ -381,8 +378,8 @@ TEST(Cuda, DynamicShapeSplit_CUDA) {
constexpr int64_t N = 4096;
VarHandle n("n", kLong);
BufHandle a("a", {n}, kFloat);
Tensor b = Compute(
"b", {{n, "n"}}, [&](const VarHandle& i) { return a.load(i) * 2.0f; });
Tensor b =
Compute("b", {n}, [&](const VarHandle& i) { return a.load(i) * 2.0f; });
LoopNest l({b});
ForPtr inner;
std::vector<ForPtr> loops = l.getLoopStmtsFor(b);
@ -914,15 +911,15 @@ TEST(Cuda, LocalMemReduce_1_CUDA) {
TEST(Cuda, HalfSupport_CUDA) {
auto half = ToDtype<at::Half>();
BufHandle a("a", {4}, half);
Tensor b = Compute("b", {{4, "n"}}, [&](const VarHandle& i) {
Tensor b = Compute("b", {4}, [&](const VarHandle& i) {
return Cast::make(half, ExprHandle(2.0f) * a.load(i));
});
Tensor c = Compute("c", {{4, "n"}}, [&](const VarHandle& i) {
Tensor c = Compute("c", {4}, [&](const VarHandle& i) {
return Cast::make(kFloat, Cast::make(half, ExprHandle(42)) + b.load(i));
});
Tensor d = Compute("d", {{4, "n"}}, [&](const VarHandle& i) {
Tensor d = Compute("d", {4}, [&](const VarHandle& i) {
return Cast::make(half, c.load(i));
});
@ -971,7 +968,7 @@ TEST(Cuda, HalfSupport_CUDA) {
TEST(Cuda, HalfPropagation_CUDA) {
auto half = ToDtype<at::Half>();
BufHandle a("a", {4}, half);
Tensor relu = Compute("relu", {{4, "n"}}, [&](const VarHandle& i) {
Tensor relu = Compute("relu", {4}, [&](const VarHandle& i) {
return Max::make(a.load(i), ExprHandle(alloc<HalfImm>(0)), true);
});
@ -987,8 +984,8 @@ TEST(Cuda, HalfPropagation_CUDA) {
const std::string& verification_pattern =
R"IR(
# CHECK: for (
# CHECK: float v = float(a[n]);
# CHECK: relu[n] = half(Max(v, 0.f
# CHECK: float v = float(a[i]);
# CHECK: relu[i] = half(Max(v, 0.f
# CHECK: })IR";
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
@ -1020,7 +1017,7 @@ TEST(Cuda, UnusedHalfArgument_CUDA) {
BufHandle a("a", {4}, kFloat);
auto half = ToDtype<at::Half>();
BufHandle b("b", {4}, half);
Tensor relu = Compute("relu", {{4, "n"}}, [&](const VarHandle& i) {
Tensor relu = Compute("relu", {4}, [&](const VarHandle& i) {
return Max::make(a.load(i), ExprHandle(alloc<FloatImm>(0)), true);
});
@ -1036,8 +1033,8 @@ TEST(Cuda, UnusedHalfArgument_CUDA) {
const std::string& verification_pattern =
R"IR(
# CHECK: for (
# CHECK: float v = a[n];
# CHECK: relu[n] = Max(v, 0.f
# CHECK: float v = a[i];
# CHECK: relu[i] = Max(v, 0.f
# CHECK: })IR";
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
@ -1150,10 +1147,9 @@ TEST(Cuda, MaskBlockDim_CUDA) {
int B_SIZE = 50;
BufHandle a_buf("a", {A_SIZE}, kFloat);
BufHandle b_buf("b", {B_SIZE}, kFloat);
Tensor c = Compute("c", {{A_SIZE, "i"}}, [&](const VarHandle& i) {
return a_buf.load(i) + 10;
});
Tensor d = Compute("d", {{B_SIZE, "i"}}, [&](const VarHandle& i) {
Tensor c = Compute(
"c", {A_SIZE}, [&](const VarHandle& i) { return a_buf.load(i) + 10; });
Tensor d = Compute("d", {B_SIZE}, [&](const VarHandle& i) {
return a_buf.load(i) + b_buf.load(i);
});
@ -1242,10 +1238,9 @@ TEST(Cuda, MaskThreadDim_CUDA) {
int B_SIZE = 100;
BufHandle a_buf("a", {A_SIZE}, kFloat);
BufHandle b_buf("b", {B_SIZE}, kFloat);
Tensor c = Compute("c", {{A_SIZE, "i"}}, [&](const VarHandle& i) {
return a_buf.load(i) + 10;
});
Tensor d = Compute("d", {{B_SIZE, "i"}}, [&](const VarHandle& i) {
Tensor c = Compute(
"c", {A_SIZE}, [&](const VarHandle& i) { return a_buf.load(i) + 10; });
Tensor d = Compute("d", {B_SIZE}, [&](const VarHandle& i) {
return a_buf.load(i / 2) + b_buf.load(i);
});
@ -1336,10 +1331,9 @@ TEST(Cuda, MaskMultiBlockDim_CUDA) {
int B_SIZE = 50;
BufHandle a_buf("a", {A_SIZE}, kFloat);
BufHandle b_buf("b", {B_SIZE}, kFloat);
Tensor c = Compute("c", {{A_SIZE, "i"}}, [&](const VarHandle& i) {
return a_buf.load(i) + 10;
});
Tensor d = Compute("d", {{B_SIZE, "i"}}, [&](const VarHandle& i) {
Tensor c = Compute(
"c", {A_SIZE}, [&](const VarHandle& i) { return a_buf.load(i) + 10; });
Tensor d = Compute("d", {B_SIZE}, [&](const VarHandle& i) {
return a_buf.load(i) + b_buf.load(i);
});
@ -1429,10 +1423,9 @@ TEST(Cuda, MaskBlockAndThreadDim_CUDA) {
int B_SIZE = 50;
BufHandle a_buf("a", {A_SIZE}, kFloat);
BufHandle b_buf("b", {B_SIZE}, kFloat);
Tensor c = Compute("c", {{A_SIZE, "i"}}, [&](const VarHandle& i) {
return a_buf.load(i) + 10;
});
Tensor d = Compute("d", {{B_SIZE, "i"}}, [&](const VarHandle& i) {
Tensor c = Compute(
"c", {A_SIZE}, [&](const VarHandle& i) { return a_buf.load(i) + 10; });
Tensor d = Compute("d", {B_SIZE}, [&](const VarHandle& i) {
return a_buf.load(i) + b_buf.load(i);
});
@ -1522,15 +1515,11 @@ TEST(Cuda, MaskMultiDim_CUDA) {
BufHandle a_buf("a", {OUTER_SIZE, A_SIZE}, kFloat);
BufHandle b_buf("b", {OUTER_SIZE, B_SIZE}, kFloat);
Tensor c = Compute(
"C",
{{OUTER_SIZE, "i"}, {A_SIZE, "j"}},
[&](const VarHandle& i, const VarHandle& j) {
"C", {OUTER_SIZE, A_SIZE}, [&](const VarHandle& i, const VarHandle& j) {
return ExprHandle(2) * a_buf.load(i, j);
});
Tensor d = Compute(
"D",
{{OUTER_SIZE, "i"}, {B_SIZE, "j"}},
[&](const VarHandle& i, const VarHandle& j) {
"D", {OUTER_SIZE, B_SIZE}, [&](const VarHandle& i, const VarHandle& j) {
return c.load(i, j * 2) + b_buf.load(i, j);
});
@ -1651,15 +1640,11 @@ TEST(Cuda, MaskMultiDimSymbolic_CUDA) {
BufHandle a_buf("a", {OUTER_SIZE, A_SIZE}, kFloat);
BufHandle b_buf("b", {OUTER_SIZE, B_SIZE}, kFloat);
Tensor c = Compute(
"C",
{{OUTER_SIZE, "i"}, {A_SIZE, "j"}},
[&](const VarHandle& i, const VarHandle& j) {
"C", {OUTER_SIZE, A_SIZE}, [&](const VarHandle& i, const VarHandle& j) {
return ExprHandle(2) * a_buf.load(i, j);
});
Tensor d = Compute(
"D",
{{OUTER_SIZE, "i"}, {B_SIZE, "j"}},
[&](const VarHandle& i, const VarHandle& j) {
"D", {OUTER_SIZE, B_SIZE}, [&](const VarHandle& i, const VarHandle& j) {
return c.load(i, j * 2) + b_buf.load(i, j);
});
@ -2062,15 +2047,11 @@ TEST(Cuda, MaskMultiDimMultiAxis_CUDA) {
BufHandle a_buf("a", {OUTER_SIZE, A_SIZE}, kFloat);
BufHandle b_buf("b", {OUTER_SIZE, B_SIZE}, kFloat);
Tensor c = Compute(
"C",
{{OUTER_SIZE, "i"}, {A_SIZE, "j"}},
[&](const VarHandle& i, const VarHandle& j) {
"C", {OUTER_SIZE, A_SIZE}, [&](const VarHandle& i, const VarHandle& j) {
return ExprHandle(2) * a_buf.load(i, j);
});
Tensor d = Compute(
"D",
{{OUTER_SIZE, "i"}, {B_SIZE, "j"}},
[&](const VarHandle& i, const VarHandle& j) {
"D", {OUTER_SIZE, B_SIZE}, [&](const VarHandle& i, const VarHandle& j) {
return c.load(i, j * 2) + b_buf.load(i, j);
});
@ -2192,15 +2173,11 @@ TEST(Cuda, MaskMultiDimMultiLevel_CUDA) {
BufHandle a_buf("a", {OUTER_A_SIZE, A_SIZE}, kFloat);
BufHandle b_buf("b", {OUTER_B_SIZE, B_SIZE}, kFloat);
Tensor c = Compute(
"C",
{{OUTER_A_SIZE, "i"}, {A_SIZE, "j"}},
[&](const VarHandle& i, const VarHandle& j) {
"C", {OUTER_A_SIZE, A_SIZE}, [&](const VarHandle& i, const VarHandle& j) {
return ExprHandle(2) * a_buf.load(i, j);
});
Tensor d = Compute(
"D",
{{OUTER_B_SIZE, "i"}, {B_SIZE, "j"}},
[&](const VarHandle& i, const VarHandle& j) {
"D", {OUTER_B_SIZE, B_SIZE}, [&](const VarHandle& i, const VarHandle& j) {
return c.load(i, j * 2) + b_buf.load(i, j);
});


@ -777,14 +777,14 @@ TEST(ExternalCall, ComputeInterop) {
Tensor Input = Compute(
"Input",
{{1, "n"}, {16, "c"}, {32, "h"}, {32, "w"}},
{1, 16, 32, 32},
[&](const VarHandle& n,
const VarHandle& c,
const VarHandle& h,
const VarHandle& w) { return FloatImm::make(5.0f); });
Tensor Weight = Compute(
"Weight",
{{16, "n"}, {16, "c"}, {1, "kh"}, {1, "kw"}},
{16, 16, 1, 1},
[&](const VarHandle& n,
const VarHandle& c,
const VarHandle& h,
@ -806,7 +806,7 @@ TEST(ExternalCall, ComputeInterop) {
{}));
Tensor Result = Compute(
"Result",
{{1, "n"}, {16, "c"}, {32, "h"}, {32, "w"}},
{1, 16, 32, 32},
[&](const VarHandle& n,
const VarHandle& c,
const VarHandle& h,
@ -866,14 +866,12 @@ TEST(ExternalCall, Inlining) {
BufHandle MatmulResultBuf("MatmulResult", {8, 8}, kFloat);
Tensor A = Compute(
"A", {{8, "i"}, {8, "j"}}, [&](const VarHandle& i, const VarHandle& j) {
return FloatImm::make(5.0f);
});
Tensor B = Compute(
"B", {{8, "i"}, {8, "j"}}, [&](const VarHandle& i, const VarHandle& j) {
return FloatImm::make(4.0f);
});
Tensor A = Compute("A", {8, 8}, [&](const VarHandle& i, const VarHandle& j) {
return FloatImm::make(5.0f);
});
Tensor B = Compute("B", {8, 8}, [&](const VarHandle& i, const VarHandle& j) {
return FloatImm::make(4.0f);
});
Tensor MatmulResult = Tensor(
MatmulResultBuf.node(),
ExternalCall::make(
@ -881,10 +879,8 @@ TEST(ExternalCall, Inlining) {
"nnc_aten_matmul",
{BufHandle(A.buf()), BufHandle(B.buf())},
{}));
Tensor Result = Compute(
"Result",
{{8, "i"}, {8, "j"}},
[&](const VarHandle& i, const VarHandle& j) {
Tensor Result =
Compute("Result", {8, 8}, [&](const VarHandle& i, const VarHandle& j) {
return MatmulResult.load(i, j) + FloatImm::make(3.0f);
});


@ -53,42 +53,36 @@ TEST(IRPrinter, FunctionName) {
int N = 20;
Tensor producer = Compute(
"producer",
{{M, "m"}, {N, "n"}},
[&](const ExprHandle& m, const ExprHandle& n) { return m * n; });
"producer", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) {
return m * n;
});
Tensor chunk_0 = Compute(
"chunk",
{{M, "m"}, {N / 2, "n"}},
[&](const ExprHandle& m, const ExprHandle& n) {
"chunk_0", {M, N / 2}, [&](const ExprHandle& m, const ExprHandle& n) {
return producer.load(m, n);
});
Tensor chunk_1 = Compute(
"chunk",
{{M, "m"}, {N / 2, "n"}},
[&](const ExprHandle& m, const ExprHandle& n) {
"chunk_1", {M, N / 2}, [&](const ExprHandle& m, const ExprHandle& n) {
return producer.load(m, n + ExprHandle(N / 2));
});
Tensor consumer = Compute(
"consumer",
{{M, "i"}, {N / 2, "j"}},
[&](const ExprHandle& i, const ExprHandle& j) {
"consumer", {M, N / 2}, [&](const ExprHandle& i, const ExprHandle& j) {
return i * chunk_1.load(i, j);
});
LoopNest l({chunk_0, chunk_1, consumer});
auto body = l.root_stmt();
auto body = LoopNest::sanitizeNames(l.root_stmt());
std::stringstream ss;
ss << *body;
const std::string& verification_pattern =
R"IR(
# CHECK: for (int i
# CHECK: for (int j
# CHECK: consumer[i, j] = i * (chunk_1[i, j])IR";
# CHECK: for (int i_2
# CHECK: for (int j_2
# CHECK: consumer[i_2, j_2] = i_2 * (chunk_1[i_2, j_2])IR";
torch::jit::testing::FileCheck().run(verification_pattern, ss.str());
}


@ -1602,7 +1602,7 @@ Tensor lowerNanToNum(
auto input_buf = c10::get<BufHandle>(inputs[0]);
auto e = Compute(
"custom_nan_to_num",
c10::fmap<DimArg>(outputShape),
outputShape,
[&](const std::vector<VarHandle>& axes) {
std::vector<ExprHandle> indices(axes.begin(), axes.end());
auto load = input_buf.load(indices);


@ -584,8 +584,7 @@ DOUBLE_INTRINSICS_TEST(lgamma, 4)
TEST(LLVM, VectorizerLoadStoreTest) {
BufHandle a("A", {1}, kInt);
Tensor c =
Compute("c", {{4, "i"}}, [&](const VarHandle& i) { return a.load(i); });
Tensor c = Compute("c", {4}, [&](const VarHandle& i) { return a.load(i); });
BufHandle c_buf(c.buf());
LoopNest l({c});
@ -606,7 +605,7 @@ TEST(LLVM, VectorizerLoadStoreTest) {
TEST(LLVM, VectorizeBitCast) {
BufHandle a("A", {128}, kInt);
Tensor c = Compute("c", {{128, "i"}}, [&](const VarHandle& i) {
Tensor c = Compute("c", {128}, [&](const VarHandle& i) {
return bitcast<float>(a.load(i));
});
@ -1186,9 +1185,8 @@ TEST(LLVM, StoreFloat) {
TEST(LLVM, SimpleMath01) {
const int N = 1024;
Tensor tensor = Compute("f", {{N, "i"}}, [](const VarHandle& i) {
return cast<float>(i * i + 1);
});
Tensor tensor = Compute(
"f", {N}, [](const VarHandle& i) { return cast<float>(i * i + 1); });
LoopNest l({tensor});
StmtPtr stmt = l.root_stmt();
BufHandle f_buf(tensor.buf());
@ -1209,9 +1207,8 @@ TEST(LLVM, ComputeMul) {
const int N = 1024;
BufHandle a("a", {N}, kFloat);
BufHandle b("b", {N}, kFloat);
Tensor c = Compute("c", {{N, "i"}}, [&](const VarHandle& i) {
return a.load(i) * b.load(i);
});
Tensor c = Compute(
"c", {N}, [&](const VarHandle& i) { return a.load(i) * b.load(i); });
BufHandle c_buf(c.buf());
LoopNest l({c});
@ -1232,10 +1229,9 @@ TEST(LLVM, BroadcastAdd) {
const int N = 1024;
BufHandle a("a", {M, N}, kFloat);
BufHandle b("b", {N}, kFloat);
Tensor c = Compute(
"c", {{M, "i"}, {N, "j"}}, [&](const VarHandle& i, const VarHandle& j) {
return a.load(i, j) + b.load(j);
});
Tensor c = Compute("c", {M, N}, [&](const VarHandle& i, const VarHandle& j) {
return a.load(i, j) + b.load(j);
});
BufHandle c_buf(c.buf());
LoopNest l({c});
@ -1333,9 +1329,8 @@ TEST(LLVM, TensorDynamicShapeAdd) {
VarHandle n("n", kInt);
BufHandle a("a", {n}, kFloat);
BufHandle b("b", {n}, kFloat);
Tensor c = Compute("c", {{n, "n"}}, [&](const VarHandle& i) {
return a.load(i) + b.load(i);
});
Tensor c = Compute(
"c", {n}, [&](const VarHandle& i) { return a.load(i) + b.load(i); });
LoopNest l({c});
StmtPtr s = l.root_stmt();
LLVMCodeGen cg(s, {a, b, c, n});
@ -1356,8 +1351,8 @@ TEST(LLVM, DynamicShape2D) {
VarHandle n("n", kInt);
BufHandle a("a", {m, n}, kFloat);
BufHandle b("b", {m, n}, kFloat);
Tensor c = Compute(
"c", {{m, "m"}, {n, "n"}}, [&](const VarHandle& i, const VarHandle& j) {
Tensor c =
Compute("c", {m, n}, [&](const VarHandle& i, const VarHandle& j) {
return a.load(i, j) + b.load(i, j);
});
LoopNest l({c});
@ -1386,7 +1381,7 @@ TEST(LLVM, EmptyStmt) {
TEST(LLVM, EliminatedStmt) {
BufHandle a("a", {1}, kFloat);
Tensor c = Compute("c", {{0, "m"}}, [&](const VarHandle& m) { return m; });
Tensor c = Compute("c", {0}, [&](const VarHandle& m) { return m; });
LoopNest l({c});
l.prepareForCodegen();
@ -1405,10 +1400,7 @@ TEST(LLVM, SimpleReduction) {
BufHandle a("a", {1, M, N}, kFloat);
// TODO: why doesn't implicit vector<DimArg> work?
std::vector<DimArg> axis = {DimArg(1)};
std::vector<DimArg> reduce_axis = {DimArg(M), DimArg(N)};
Tensor b = Reduce("sum", axis, Sum(), a, reduce_axis);
Tensor b = Reduce("sum", {1}, Sum(), a, {M, N});
LoopNest loop({b});
loop.prepareForCodegen();
@ -1442,10 +1434,7 @@ TEST(LLVM, RFactorReduction) {
BufHandle a("a", {1, M, N}, kFloat);
// TODO: why doesn't implicit vector<DimArg> work?
std::vector<DimArg> axis = {DimArg(1)};
std::vector<DimArg> reduce_axis = {DimArg(M), DimArg(N)};
Tensor b = Reduce("sum", axis, Sum(), a, reduce_axis);
Tensor b = Reduce("sum", {1}, Sum(), a, {M, N});
LoopNest loop({b});
std::vector<ForPtr> loops = loop.getLoopStmtsFor(b);
@ -1490,7 +1479,7 @@ TEST(LLVM, RFactorVectorizedReduction) {
BufHandle a("a", {1, M, N}, kFloat);
Tensor b = Reduce("sum", {{1, "K"}}, Sum(), a, {{M, "M"}, {N, "N"}});
Tensor b = Reduce("sum", {1}, Sum(), a, {M, N});
LoopNest loopnest({b});
std::vector<ForPtr> loops = loopnest.getLoopStmtsFor(b);
// Reorder n and m loops
@ -1536,10 +1525,9 @@ static void testSimpleParallel() {
// parallel or sequential.
const int M = 4;
const int N = 6;
Tensor f = Compute(
"f", {{M, "m"}, {N, "n"}}, [](const VarHandle& m, const VarHandle& n) {
return cast<float>(m + n);
});
Tensor f = Compute("f", {M, N}, [](const VarHandle& m, const VarHandle& n) {
return cast<float>(m + n);
});
LoopNest loop_nest({f});
auto const& loops = loop_nest.getLoopStmtsFor(f);
ForPtr m = loops[0];
@ -1588,20 +1576,14 @@ TEST(LLVM, CompositeParallel) {
for (const auto test_cfg : c10::irange(test_count)) {
int M = 5;
int N = 7;
Tensor t1 =
Compute("t1", {{M, "M"}}, [](const VarHandle& m) { return m + 1.f; });
Tensor t2 =
Compute("t2", {{N, "N"}}, [](const VarHandle& n) { return n + 2.f; });
Tensor t3 = Compute(
"t3",
{{M, "M"}, {N, "N"}},
[=](const VarHandle& m, const VarHandle& n) {
Tensor t1 = Compute("t1", {M}, [](const VarHandle& m) { return m + 1.f; });
Tensor t2 = Compute("t2", {N}, [](const VarHandle& n) { return n + 2.f; });
Tensor t3 =
Compute("t3", {M, N}, [=](const VarHandle& m, const VarHandle& n) {
return t1.load(m) * t2.load(n);
});
Tensor t4 = Compute(
"t4",
{{M, "M"}, {N, "N"}},
[=](const VarHandle& m, const VarHandle& n) {
Tensor t4 =
Compute("t4", {M, N}, [=](const VarHandle& m, const VarHandle& n) {
return t3.load(m, n) + m + n;
});
LoopNest loop_nest({t4}, {t1, t2, t3, t4});
@ -1657,12 +1639,12 @@ TEST(LLVM, VectorizedGEMM) {
BufHandle BP("B", {K, N}, kFloat);
Tensor CT = Reduce(
"gemm",
{{M, "M"}, {N, "N"}},
{M, N},
Sum(),
[&](const ExprHandle& m, const ExprHandle& n, const ExprHandle& k) {
return AP.load(m, k) * BP.load(k, n);
},
{{K, "K"}});
{K});
LoopNest loop({CT});
{
@ -1735,10 +1717,9 @@ TEST(LLVM, CallRaw) {
VarHandle N("N", kInt);
BufHandle a("a", {M, N}, kFloat);
BufHandle b("b", {N}, kFloat);
Tensor c = Compute(
"c", {{M, "i"}, {N, "j"}}, [&](const VarHandle& i, const VarHandle& j) {
return a.load(i, j) + b.load(j);
});
Tensor c = Compute("c", {M, N}, [&](const VarHandle& i, const VarHandle& j) {
return a.load(i, j) + b.load(j);
});
LoopNest l({c});
l.prepareForCodegen();
@ -1776,7 +1757,7 @@ TEST(LLVM, CustomTarget) {
BufHandle a("a", {M}, kFloat);
BufHandle b("b", {M}, kFloat);
BufHandle c("c", {M}, kFloat);
Tensor d = Compute("d", {{M, "m"}}, [&](const VarHandle& m) {
Tensor d = Compute("d", {M}, [&](const VarHandle& m) {
return a.load(m) * b.load(m) + c.load(m);
});
LoopNest nest({d});

File diff suppressed because it is too large.


@ -2696,13 +2696,13 @@ TEST(MemDependency, MemDependencyCheckerComputeAPI) {
BufHandle b_buf("b", {5, 6}, kFloat);
Tensor c = Compute(
"broadcast_add",
{{4, "m"}, {5, "n"}, {6, "k"}},
{4, 5, 6},
[&](const VarHandle& m, const VarHandle& n, const VarHandle& k) {
return a_buf.load(m, n) + b_buf.load(n, k);
});
Tensor d = Compute(
"d",
{{4, "m"}, {5, "n"}, {6, "k"}},
{4, 5, 6},
[&](const VarHandle& m, const VarHandle& n, const VarHandle& k) {
return c.load(m, n, k) + 1;
});
@ -2741,13 +2741,13 @@ TEST(MemDependency, MemDependencyCheckerComputeInline) {
BufHandle b_buf("b", {5, 6}, kFloat);
Tensor c = Compute(
"broadcast_add",
{{4, "m"}, {5, "n"}, {6, "k"}},
{4, 5, 6},
[&](const VarHandle& m, const VarHandle& n, const VarHandle& k) {
return a_buf.load(m, n) + b_buf.load(n, k);
});
Tensor d = Compute(
"d",
{{4, "m"}, {5, "n"}, {6, "k"}},
{4, 5, 6},
[&](const VarHandle& m, const VarHandle& n, const VarHandle& k) {
return c.load(m, n, k) + 1;
});
@ -2776,7 +2776,7 @@ TEST(MemDependency, MemDependencyCheckerComputeSplit) {
BufHandle b_buf("b", {5, 6}, kFloat);
Tensor c = Compute(
"broadcast_add",
{{4, "m"}, {5, "n"}, {6, "k"}},
{4, 5, 6},
[&](const VarHandle& m, const VarHandle& n, const VarHandle& k) {
return a_buf.load(m, n) + b_buf.load(n, k);
});
@ -2822,7 +2822,7 @@ TEST(MemDependency, MemDependencyCheckerComputeReorder) {
BufHandle b_buf("b", {5, 6}, kFloat);
Tensor c = Compute(
"broadcast_add",
{{4, "m"}, {5, "n"}, {6, "k"}},
{4, 5, 6},
[&](const VarHandle& m, const VarHandle& n, const VarHandle& k) {
return a_buf.load(m, n) + b_buf.load(n, k);
});
@ -2888,11 +2888,11 @@ TEST(MemDependency, MemDependencyCheckerComputeReduce) {
Tensor c = Compute(
"scale",
{{2, "l2"}, {3, "n1"}, {6, "m1"}},
{2, 3, 6},
[&](const VarHandle& l, const VarHandle& n, const VarHandle& m) {
return b.load(l, n, m) * a.load(l, n, m);
});
Tensor d = Reduce("sum", {{2, "l1"}}, Sum(), c, {{3, "n1"}, {6, "m1"}});
Tensor d = Reduce("sum", {2}, Sum(), c, {3, 6});
LoopNest l({d}, {c, d});
MemDependencyChecker analyzer({a.node(), b.node()}, {d.buf()});
@ -2924,12 +2924,12 @@ TEST(MemDependency, MemDependencyCheckerComputeGEMM) {
BufHandle BP("B", {K, N}, kFloat);
Tensor CT = Reduce(
"gemm",
{{M, "M"}, {N, "N"}},
{M, N},
Sum(),
[&](const ExprHandle& m, const ExprHandle& n, const ExprHandle& k) {
return AP.load(m, k) * BP.load(k, n);
},
{{K, "K"}});
{K});
LoopNest loop({CT});
{


@ -95,30 +95,24 @@ TEST(MemPlanning, SameBufSizeMemReuse) {
Tensor CT = Reduce(
"gemm",
{{M, "M"}, {N, "N"}},
{M, N},
Sum(),
[&](const ExprHandle& m, const ExprHandle& n, const ExprHandle& k) {
return AP.load(m, k) * BP.load(k, n);
},
{{K, "K"}});
Tensor DT = Compute(
"relu",
{{M, "M"}, {N, "N"}},
[&](const ExprHandle& m, const ExprHandle& n) {
{K});
Tensor DT =
Compute("relu", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) {
auto zero = Cast::make(CT.buf()->dtype(), 0);
return CompareSelect::make(
CT.load(m, n), zero, zero, CT.load(m, n), kLT);
});
Tensor ET = Compute(
"add",
{{M, "M"}, {N, "N"}},
[&](const ExprHandle& m, const ExprHandle& n) {
Tensor ET =
Compute("add", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) {
return DT.load(m, n) + DT.load(m, n);
});
Tensor FT = Compute(
"mul",
{{M, "M"}, {N, "N"}},
[&](const ExprHandle& m, const ExprHandle& n) {
Tensor FT =
Compute("mul", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) {
return ET.load(m, n) * ET.load(m, n);
});
auto stmt = Block::make({CT.stmt(), DT.stmt(), ET.stmt(), FT.stmt()});
@ -188,36 +182,28 @@ TEST(MemPlanning, SameBufSizeMultiMemReuses) {
Tensor CT = Reduce(
"gemm",
{{M, "M"}, {N, "N"}},
{M, N},
Sum(),
[&](const ExprHandle& m, const ExprHandle& n, const ExprHandle& k) {
return AP.load(m, k) * BP.load(k, n);
},
{{K, "K"}});
Tensor DT = Compute(
"relu",
{{M, "M"}, {N, "N"}},
[&](const ExprHandle& m, const ExprHandle& n) {
{K});
Tensor DT =
Compute("relu", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) {
auto zero = Cast::make(CT.buf()->dtype(), 0);
return CompareSelect::make(
CT.load(m, n), zero, zero, CT.load(m, n), kLT);
});
Tensor ET = Compute(
"add",
{{M, "M"}, {N, "N"}},
[&](const ExprHandle& m, const ExprHandle& n) {
Tensor ET =
Compute("add", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) {
return DT.load(m, n) + DT.load(m, n);
});
Tensor FT = Compute(
"mul",
{{M, "M"}, {N, "N"}},
[&](const ExprHandle& m, const ExprHandle& n) {
Tensor FT =
Compute("mul", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) {
return ET.load(m, n) * ET.load(m, n);
});
Tensor GT = Compute(
"sub",
{{M, "M"}, {N, "N"}},
[&](const ExprHandle& m, const ExprHandle& n) {
Tensor GT =
Compute("sub", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) {
return FT.load(m, n) - ET.load(m, n);
});
@ -296,42 +282,32 @@ TEST(MemPlanning, SameBufSizeMultiMemReusesOfOneBuf) {
Tensor CT = Reduce(
"gemm",
{{M, "M"}, {N, "N"}},
{M, N},
Sum(),
[&](const ExprHandle& m, const ExprHandle& n, const ExprHandle& k) {
return AP.load(m, k) * BP.load(k, n);
},
{{K, "K"}});
Tensor DT = Compute(
"relu",
{{M, "M"}, {N, "N"}},
[&](const ExprHandle& m, const ExprHandle& n) {
{K});
Tensor DT =
Compute("relu", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) {
auto zero = Cast::make(CT.buf()->dtype(), 0);
return CompareSelect::make(
CT.load(m, n), zero, zero, CT.load(m, n), kLT);
});
Tensor ET = Compute(
"add",
{{M, "M"}, {N, "N"}},
[&](const ExprHandle& m, const ExprHandle& n) {
Tensor ET =
Compute("add", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) {
return DT.load(m, n) + DT.load(m, n);
});
Tensor FT = Compute(
"mul",
{{M, "M"}, {N, "N"}},
[&](const ExprHandle& m, const ExprHandle& n) {
Tensor FT =
Compute("mul", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) {
return ET.load(m, n) * ET.load(m, n);
});
Tensor GT = Compute(
"sub",
{{M, "M"}, {N, "N"}},
[&](const ExprHandle& m, const ExprHandle& n) {
Tensor GT =
Compute("sub", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) {
return FT.load(m, n) - 1;
});
Tensor HT = Compute(
"div",
{{M, "M"}, {N, "N"}},
[&](const ExprHandle& m, const ExprHandle& n) {
Tensor HT =
Compute("div", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) {
return GT.load(m, n) / 2;
});
@ -418,30 +394,24 @@ TEST(MemPlanning, SmallerBufSizeNonMemReuse) {
Tensor CT = Reduce(
"gemm",
{{M, "M"}, {N, "N"}},
{M, N},
Sum(),
[&](const ExprHandle& m, const ExprHandle& n, const ExprHandle& k) {
return AP.load(m, k) * BP.load(k, n);
},
{{K, "K"}});
Tensor DT = Compute(
"relu",
{{M, "M"}, {N, "N"}},
[&](const ExprHandle& m, const ExprHandle& n) {
{K});
Tensor DT =
Compute("relu", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) {
auto zero = Cast::make(CT.buf()->dtype(), 0);
return CompareSelect::make(
CT.load(m, n), zero, zero, CT.load(m, n), kLT);
});
Tensor ET = Compute(
"add",
{{M * 2, "EM"}, {N * 2, "EN"}},
[&](const ExprHandle& em, const ExprHandle& en) {
"add", {M * 2, N * 2}, [&](const ExprHandle& em, const ExprHandle& en) {
return DT.load(em / 2, en / 2) + DT.load(em / 2, en / 2);
});
Tensor FT = Compute(
"mul",
{{M * 2, "FM"}, {N * 2, "FN"}},
[&](const ExprHandle& fm, const ExprHandle& fn) {
"mul", {M * 2, N * 2}, [&](const ExprHandle& fm, const ExprHandle& fn) {
return ET.load(fm, fn) * ET.load(fm, fn);
});
auto stmt = Block::make({CT.stmt(), DT.stmt(), ET.stmt(), FT.stmt()});


@ -35,7 +35,7 @@ TEST(Reductions, ReduceSum0D_1) {
std::vector<float> out(M, -1.f);
Tensor c = Reduce("sum", {{M, "m"}}, Sum(), b, {});
Tensor c = Reduce("sum", {M}, Sum(), b, {});
LoopNest loop({c});
loop.prepareForCodegen();
StmtPtr s = loop.root_stmt();
@ -80,7 +80,7 @@ TEST(Reductions, ReduceSum1D) {
std::vector<float> out(1, -1.f);
Tensor c = Reduce("sum", {}, Sum(), b, {{10, "m"}});
Tensor c = Reduce("sum", {}, Sum(), b, {10});
LoopNest loop({c});
loop.prepareForCodegen();
StmtPtr s = loop.root_stmt();
@ -109,7 +109,7 @@ TEST(Reductions, ReduceSum2D) {
std::vector<float> out(M, -1.f);
Tensor c = Reduce("sum", {{M, "m"}}, Sum(), b, {{N, "n"}});
Tensor c = Reduce("sum", {M}, Sum(), b, {N});
LoopNest loop({c});
loop.prepareForCodegen();
StmtPtr s = loop.root_stmt();
@ -138,7 +138,7 @@ TEST(Reductions, ReduceSum3D) {
BufHandle b("b", {2, 3, m}, kFloat);
Tensor c = Reduce("sum", {{2, "l"}, {3, "n"}}, Sum(), b, {{m, "m"}});
Tensor c = Reduce("sum", {2, 3}, Sum(), b, {m});
LoopNest loop({c});
loop.prepareForCodegen();
StmtPtr s = loop.root_stmt();
@ -168,7 +168,7 @@ TEST(Reductions, ReduceSum3D) {
ASSERT_EQ(cData[i], expected);
}
Tensor d = Reduce("sum2", {{2, "l"}}, Sum(), b, {{3, "n"}, {m, "m"}});
Tensor d = Reduce("sum2", {2}, Sum(), b, {3, m});
LoopNest loop2({d});
loop2.prepareForCodegen();
StmtPtr s2 = loop2.root_stmt();
@ -186,7 +186,7 @@ TEST(Reductions, ReduceSum3D) {
// This is the same as just reducing the original result across that axis.
BufHandle c_buf(c.buf());
Tensor e = Reduce("sum3", {{2, "l"}}, Sum(), c_buf, {{3, "m"}});
Tensor e = Reduce("sum3", {2}, Sum(), c_buf, {3});
LoopNest loop3({e});
loop3.prepareForCodegen();
StmtPtr s3 = loop3.root_stmt();
@ -210,12 +210,7 @@ TEST(Reductions, ReduceSum10D) {
std::vector<float> in(InputSize, 1.f);
std::vector<float> out(OutputSize, -1.f);
Tensor c = Reduce(
"sum",
{{2, "a"}, {3, "b"}, {2, "c"}, {3, "d"}, {2, "e"}},
Sum(),
in_,
{{3, "f"}, {2, "g"}, {3, "h"}, {2, "i"}, {3, "j"}});
Tensor c = Reduce("sum", {2, 3, 2, 3, 2}, Sum(), in_, {3, 2, 3, 2, 3});
LoopNest loop({c});
loop.prepareForCodegen();
StmtPtr s = loop.root_stmt();
@ -250,7 +245,7 @@ TEST(Reductions, ReduceProduct) {
Reducer product(
ExprHandle(1.f), [](ExprHandle a, ExprHandle b) { return a * b; });
Tensor c = Reduce("product", {{M, "m"}}, product, b, {{N, "n"}});
Tensor c = Reduce("product", {M}, product, b, {N});
LoopNest loop({c});
loop.prepareForCodegen();
StmtPtr s = loop.root_stmt();
@ -281,7 +276,7 @@ TEST(Reductions, ReduceMax) {
in[j] = j;
}
Tensor dm1 = Reduce("max", {}, Maximum(kFloat), in_, {{10, "m"}});
Tensor dm1 = Reduce("max", {}, Maximum(kFloat), in_, {10});
LoopNest loop({dm1});
loop.prepareForCodegen();
@ -296,7 +291,7 @@ TEST(Reductions, ReduceMax) {
BufHandle in2_("b", {2, 5}, kFloat);
std::vector<float> out2(2, -1.f);
Tensor m2d = Reduce("max", {{2, "n"}}, Maximum(kFloat), in2_, {{5, "m"}});
Tensor m2d = Reduce("max", {2}, Maximum(kFloat), in2_, {5});
LoopNest loop2({m2d});
loop2.prepareForCodegen();
@ -326,7 +321,7 @@ TEST(Reductions, ReduceMinCustomInitializer) {
{},
Minimum(ExprHandle(minInit)),
[&](ParameterList& v) { return in_.load(v); },
{{10, "m"}});
{10});
LoopNest loop({min});
loop.prepareForCodegen();
@ -357,12 +352,12 @@ TEST(Reductions, ReduceAnyAll) {
Tensor any = Reduce(
"anyEqual",
{{4, "i"}},
{4},
anyEqSV,
[&](const auto& i, const auto& j) {
return CompareSelect::make(b.load(i, j), searchValue, kEQ);
},
{{10, "j"}});
{10});
LoopNest loop({any});
loop.prepareForCodegen();
@ -400,12 +395,12 @@ TEST(Reductions, ReduceAnyAll) {
Tensor allGreaterThan = Reduce(
"allGreaterThan",
{{4, "i"}},
{4},
allGTSV,
[&](const auto& i, const auto& j) {
return CompareSelect::make(b.load(i, j), searchValue, kGT);
},
{{10, "j"}});
{10});
LoopNest loop2({allGreaterThan});
loop2.prepareForCodegen();
@ -448,12 +443,12 @@ TEST(Reductions, ReduceMatmul2D) {
Tensor mm = Reduce(
"mm",
{{3, "m"}, {3, "n"}},
{3, 3},
Sum(),
[&](const ExprHandle& m, const ExprHandle& n, const ExprHandle& k) {
return tA.load(m, k) * tB.load(k, n);
},
{{2, "k"}});
{2});
LoopNest loop({mm});
loop.prepareForCodegen();
@ -480,10 +475,10 @@ TEST(Reductions, ReduceRfactorLike) {
std::vector<float> in_rf_(10, -2.f);
std::vector<float> out(1, -1.f);
Tensor l1 = Reduce("l1", {{10, "i"}}, Sum(), in, {{10, "j"}});
Tensor l1 = Reduce("l1", {10}, Sum(), in, {10});
BufHandle in_rf(l1.buf());
Tensor l2 = Reduce("l2", {}, Sum(), in_rf, {{10, "i"}});
Tensor l2 = Reduce("l2", {}, Sum(), in_rf, {10});
LoopNest loop({l1, l2});
loop.prepareForCodegen();
@ -503,11 +498,9 @@ TEST(Reductions, ReduceAsProducer) {
BufHandle a("a", {2, 3}, kFloat);
BufHandle b("b", {2, 3, m}, kFloat);
Tensor c = Reduce("sum", {{2, "l1"}, {3, "n1"}}, Sum(), b, {{m, "m1"}});
Tensor d = Compute(
"scale",
{{2, "l2"}, {3, "n1"}},
[&](const VarHandle& l, const VarHandle& n) {
Tensor c = Reduce("sum", {2, 3}, Sum(), b, {m});
Tensor d =
Compute("scale", {2, 3}, [&](const VarHandle& l, const VarHandle& n) {
return c.load(l, n) * a.load(l, n);
});
LoopNest loop({d}, {c, d});
@ -548,11 +541,11 @@ TEST(Reductions, ReduceAsConsumer) {
Tensor c = Compute(
"scale",
{{2, "l2"}, {3, "n1"}, {m, "m1"}},
{2, 3, m},
[&](const VarHandle& l, const VarHandle& n, const VarHandle& m) {
return b.load(l, n, m) * a.load(l, n, m);
});
Tensor d = Reduce("sum", {{2, "l1"}}, Sum(), c, {{3, "n1"}, {m, "m1"}});
Tensor d = Reduce("sum", {2}, Sum(), c, {3, m});
LoopNest loop({d}, {c, d});
loop.prepareForCodegen();
StmtPtr s = loop.root_stmt();
@ -599,7 +592,7 @@ TEST(Reductions, SplitReduceAxis) {
}
std::vector<float> out(16, -1.f);
Tensor tensor = Reduce("sum", {{16, "m"}}, Sum(), in, {{8, "n"}});
Tensor tensor = Reduce("sum", {16}, Sum(), in, {8});
LoopNest l({tensor});
std::vector<ForPtr> loops = l.getLoopStmtsFor(tensor);
LoopNest::splitWithTail(loops[1], 2);
@ -627,7 +620,7 @@ TEST(Reductions, SplitNonReduceAxis) {
}
}
std::vector<float> out(16, -1.f);
Tensor tensor = Reduce("sum", {{16, "m"}}, Sum(), in, {{8, "n"}});
Tensor tensor = Reduce("sum", {16}, Sum(), in, {8});
LoopNest l({tensor});
std::vector<ForPtr> loops = l.getLoopStmtsFor(tensor);
LoopNest::splitWithTail(loops[0], 2);
@ -657,14 +650,14 @@ TEST(Reductions, ReorderedReductionInitializer) {
BufHandle in("in", {1, 12, 6}, kFloat);
std::vector<float> in_(12 * 6, 1.f);
Tensor tensor_ = Reduce("sum", {{1, "k"}, {12, "n"}}, Sum(), in, {{6, "m"}});
Tensor tensor_ = Reduce("sum", {1, 12}, Sum(), in, {6});
LoopNest l_({tensor_});
l_.prepareForCodegen();
StmtPtr s_ = Stmt::clone(l_.root_stmt());
s_ = IRSimplifier::simplify(s_);
Tensor tensor = Reduce("sum", {{1, "k"}, {12, "n"}}, Sum(), in, {{6, "m"}});
Tensor tensor = Reduce("sum", {1, 12}, Sum(), in, {6});
LoopNest l({tensor});
auto loops = l.getLoopStmtsFor(tensor);
@ -709,7 +702,7 @@ TEST(Reductions, ReduceRfactor) {
std::vector<float> out(1, -1.f);
Tensor c = Reduce("sum", {}, Sum(), b, {{m, "m"}, {n, "n"}});
Tensor c = Reduce("sum", {}, Sum(), b, {m, n});
LoopNest loop({c});
std::vector<ForPtr> loops = loop.getLoopStmtsFor(c);
auto c_body = loop.getAllWritesToBuf(c.buf())[1];
@ -742,7 +735,7 @@ TEST(Reductions, Reduce3DRfactorInner) {
std::vector<float> out(1, -1.f);
Tensor c = Reduce("sum", {}, Sum(), b, {{m, "m"}, {n, "n"}, {k, "k"}});
Tensor c = Reduce("sum", {}, Sum(), b, {m, n, k});
LoopNest loop({c});
std::vector<ForPtr> loops = loop.getLoopStmtsFor(c);
auto c_body = loop.getAllWritesToBuf(c.buf())[1];
@ -775,7 +768,7 @@ TEST(Reductions, Reduce3DRfactorOuter) {
std::vector<float> out(1, -1.f);
Tensor c = Reduce("sum", {}, Sum(), b, {{m, "m"}, {n, "n"}, {k, "k"}});
Tensor c = Reduce("sum", {}, Sum(), b, {m, n, k});
LoopNest loop({c});
std::vector<ForPtr> loops = loop.getLoopStmtsFor(c);
auto c_body = loop.getAllWritesToBuf(c.buf())[1];
@ -799,12 +792,7 @@ TEST(Reductions, ReduceRepeatedInternalRfactor) {
std::vector<float> out(1, -1.f);
std::vector<float> ref(1, -1.f);
Tensor c = Reduce(
"sum",
{},
Sum(),
in_,
{{2, "a"}, {3, "b"}, {4, "c"}, {5, "d"}, {6, "e"}});
Tensor c = Reduce("sum", {}, Sum(), in_, {2, 3, 4, 5, 6});
LoopNest orig_loop({c});
// Try rfactoring N outer loops
@ -850,7 +838,7 @@ TEST(Reductions, ReduceSplitTail) {
for (const auto i : c10::irange(3)) {
std::vector<float> out(M, -1.f);
Tensor c = Reduce("sum", {{M, "m"}}, Sum(), b, {{N, "n"}, {K, "k"}});
Tensor c = Reduce("sum", {M}, Sum(), b, {N, K});
LoopNest loop({c});
std::vector<ForPtr> loops = loop.getLoopStmtsFor(c);
LoopNest::splitWithTail(loops[i], 8);
@ -880,7 +868,7 @@ TEST(Reductions, ReduceSplitNoTail) {
for (const auto i : c10::irange(3)) {
std::vector<float> out(M, -1.f);
Tensor c = Reduce("sum", {{M, "m"}}, Sum(), b, {{N, "n"}, {K, "k"}});
Tensor c = Reduce("sum", {M}, Sum(), b, {N, K});
LoopNest loop({c});
std::vector<ForPtr> loops = loop.getLoopStmtsFor(c);
LoopNest::splitWithTail(loops[i], 5);
@ -912,7 +900,7 @@ TEST(Reductions, ReduceOverSplitTail) {
for (const auto i : c10::irange(3)) {
std::vector<float> out(M, -1.f);
Tensor c = Reduce("sum", {{M, "m"}}, Sum(), b, {{N, "n"}, {K, "k"}});
Tensor c = Reduce("sum", {M}, Sum(), b, {N, K});
LoopNest loop({c});
std::vector<ForPtr> loops = loop.getLoopStmtsFor(c);
LoopNest::splitWithTail(loops[i], 16);
@ -943,7 +931,7 @@ TEST(Reductions, ReduceSplitMask) {
for (const auto i : c10::irange(3)) {
std::vector<float> out(M, -1.f);
Tensor c = Reduce("sum", {{M, "m"}}, Sum(), b, {{N, "n"}, {K, "k"}});
Tensor c = Reduce("sum", {M}, Sum(), b, {N, K});
LoopNest loop({c});
std::vector<ForPtr> loops = loop.getLoopStmtsFor(c);
LoopNest::splitWithMask(loops[i], 8);
@ -973,7 +961,7 @@ TEST(Reductions, ReduceSplitNoMask) {
for (const auto i : c10::irange(3)) {
std::vector<float> out(M, -1.f);
Tensor c = Reduce("sum", {{M, "m"}}, Sum(), b, {{N, "n"}, {K, "k"}});
Tensor c = Reduce("sum", {M}, Sum(), b, {N, K});
LoopNest loop({c});
std::vector<ForPtr> loops = loop.getLoopStmtsFor(c);
LoopNest::splitWithMask(loops[i], 5);
@ -1004,7 +992,7 @@ TEST(Reductions, ReduceOverSplitMask) {
for (const auto i : c10::irange(3)) {
std::vector<float> out(M, -1.f);
Tensor c = Reduce("sum", {{M, "m"}}, Sum(), b, {{N, "n"}, {K, "k"}});
Tensor c = Reduce("sum", {M}, Sum(), b, {N, K});
LoopNest loop({c});
std::vector<ForPtr> loops = loop.getLoopStmtsFor(c);
LoopNest::splitWithMask(loops[i], 16);
@ -1038,7 +1026,7 @@ TEST(Reductions, ReduceSplitRfactor) {
std::vector<float> out(M, -1.f);
Tensor c = Reduce("sum", {{M, "m"}}, Sum(), b, {{N, "n"}, {K, "k"}});
Tensor c = Reduce("sum", {M}, Sum(), b, {N, K});
LoopNest loop({c});
std::vector<ForPtr> loops = loop.getLoopStmtsFor(c);
LoopNest::splitWithTail(loops[2], SPLIT_FACTOR);
@ -1078,7 +1066,7 @@ TEST(Reductions, ReduceOverSplitRfactor) {
std::vector<float> out(1, -1.f);
Tensor c = Reduce("sum", {}, Sum(), b, {{N, "n"}, {K, "k"}});
Tensor c = Reduce("sum", {}, Sum(), b, {N, K});
LoopNest loop({c});
std::vector<ForPtr> loops = loop.getLoopStmtsFor(c);
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
@ -1128,10 +1116,9 @@ TEST(Reductions, ReduceInlineReduction) {
BufHandle a_buf("a", {M}, kFloat);
BufHandle b_buf("b", {M, N, K}, kFloat);
Tensor x = Reduce("x", {{M, "m1"}}, Sum(), b_buf, {{N, "n1"}, {K, "k1"}});
Tensor y = Compute("y", {{M, "m2"}}, [&](const VarHandle& m) {
return a_buf.load(m) + x.load(m);
});
Tensor x = Reduce("x", {M}, Sum(), b_buf, {N, K});
Tensor y = Compute(
"y", {M}, [&](const VarHandle& m) { return a_buf.load(m) + x.load(m); });
PaddedBuffer<float> a_v(M);
PaddedBuffer<float> b_v(M, N, K);
@ -1162,11 +1149,11 @@ TEST(Reductions, ReduceInlineConsumer) {
Tensor x = Compute(
"x",
{{M, "m1"}, {N, "n1"}, {K, "k1"}},
{M, N, K},
[&](const VarHandle& m, const VarHandle& n, const VarHandle& k) {
return a_buf.load(m, n, k) + b_buf.load(m, n, k);
});
Tensor y = Reduce("y", {{M, "m2"}}, Sum(), x, {{N, "n2"}, {K, "k2"}});
Tensor y = Reduce("y", {M}, Sum(), x, {N, K});
PaddedBuffer<float> a_v(M, N, K);
PaddedBuffer<float> b_v(M, N, K);
@ -1215,7 +1202,7 @@ TEST(Reductions, ReduceInlineReducerInternal) {
Tensor x = Compute(
"x",
{{M, "m1"}, {N, "n1"}, {K, "k1"}},
{M, N, K},
[&](const VarHandle& m, const VarHandle& n, const VarHandle& k) {
return a_buf.load(m, n, k) + b_buf.load(m, n, k);
});
@ -1223,7 +1210,7 @@ TEST(Reductions, ReduceInlineReducerInternal) {
Reducer minimum(ExprHandle(0.f), [&](ExprHandle a, ExprHandle b) {
return Add::make(ExprHandle(1.f), Min::make(a, b, false));
});
Tensor y = Reduce("y", {{M, "m2"}}, minimum, x, {{N, "n2"}, {K, "k2"}});
Tensor y = Reduce("y", {M}, minimum, x, {N, K});
PaddedBuffer<float> a_v(M, N, K);
PaddedBuffer<float> b_v(M, N, K);
@ -1272,26 +1259,28 @@ TEST(Reductions, ReductionCacheAccessesOperatorAxis) {
Tensor c = Compute(
"scale",
{{L, "l2"}, {N, "n1"}, {M, "m1"}},
{L, N, M},
[&](const VarHandle& l, const VarHandle& n, const VarHandle& m) {
return b.load(l, n, m) * a.load(l, n, m);
});
Tensor d = Reduce("sum", {{L, "l1"}}, Sum(), c, {{N, "n1"}, {M, "m1"}});
Tensor d = Reduce("sum", {L}, Sum(), c, {N, M});
Tensor e = Compute("scale", {{L, "l"}}, [&](const VarHandle& l) {
Tensor e = Compute("scale", {L}, [&](const VarHandle& l) {
return b.load(0, 0, l) * d.load(l);
});
LoopNest l({e}, {c, d, e});
LoopNest l_before(l);
l_before.prepareForCodegen();
SimpleIREvaluator cg_before(l_before.root_stmt(), {a, b, e});
SimpleIREvaluator cg_before(
LoopNest::sanitizeNames(l_before.root_stmt()), {a, b, e});
StmtPtr d_loop = l.getLoopStmtsFor(d)[0];
l.cacheAccesses(d.buf(), "d_local", d_loop);
l.prepareForCodegen();
StmtPtr result = IRSimplifier::simplify(l.root_stmt());
StmtPtr result =
LoopNest::sanitizeNames(IRSimplifier::simplify(l.root_stmt()));
SimpleIREvaluator cg_after(result, {a, b, e});
std::ostringstream oss;
@ -1299,16 +1288,16 @@ TEST(Reductions, ReductionCacheAccessesOperatorAxis) {
const std::string& expected_ir =
R"IR(
#CHECK: Allocate(d_local); // dtype=float, dims=[4]
#CHECK: for (int l1
#CHECK: d_local[l1] = 0.f
#CHECK: for (int n1
#CHECK: for (int m1
#CHECK: d_local[l1] = (d_local[l1]) + (scale[
#CHECK: for (int i_2
#CHECK: d_local[i_2] = 0.f
#CHECK: for (int
#CHECK: for (int
#CHECK: d_local[i_2] = (d_local[i_2]) + (scale[
#CHECK: }
#CHECK: }
#CHECK: }
#CHECK: for (int i
#CHECK: sum[i] = d_local[i]
#CHECK: for (int i_3
#CHECK: sum[i_3] = d_local[i_3]
#CHECK: Free(d_local);
#CHECK-NOT: d_local
)IR";
@ -1347,13 +1336,13 @@ TEST(Reductions, ReductionCacheAccessesOuterReduceAxis) {
Tensor c = Compute(
"scale",
{{L, "l2"}, {N, "n1"}, {M, "m1"}},
{L, N, M},
[&](const VarHandle& l, const VarHandle& n, const VarHandle& m) {
return b.load(l, n, m) * a.load(l, n, m);
});
Tensor d = Reduce("sum", {{L, "l1"}}, Sum(), c, {{N, "n1"}, {M, "m1"}});
Tensor d = Reduce("sum", {L}, Sum(), c, {N, M});
Tensor e = Compute("scale", {{L, "l"}}, [&](const VarHandle& l) {
Tensor e = Compute("scale", {L}, [&](const VarHandle& l) {
return b.load(0, 0, l) * d.load(l);
});
@ -1366,7 +1355,8 @@ TEST(Reductions, ReductionCacheAccessesOuterReduceAxis) {
l.cacheAccesses(d.buf(), "d_local", d_loop);
l.prepareForCodegen();
StmtPtr result = IRSimplifier::simplify(l.root_stmt());
StmtPtr result =
LoopNest::sanitizeNames(IRSimplifier::simplify(l.root_stmt()));
SimpleIREvaluator cg_after(result, {a, b, e});
std::ostringstream oss;
@ -1374,14 +1364,14 @@ TEST(Reductions, ReductionCacheAccessesOuterReduceAxis) {
const std::string& expected_ir =
R"IR(
#CHECK: Allocate(d_local); // dtype=float, dims=[1]
#CHECK: sum[l1] = 0
#CHECK: d_local[0] = sum[l1]
#CHECK: for (int n1
#CHECK: for (int m1
#CHECK: sum[i_1] = 0
#CHECK: d_local[0] = sum[i_1]
#CHECK: for (int j_1
#CHECK: for (int k_1
#CHECK: d_local[0] = (d_local[0]) + (scale[
#CHECK: }
#CHECK: }
#CHECK: sum[l1] = d_local[0]
#CHECK: sum[i_1] = d_local[0]
#CHECK: Free(d_local);
#CHECK-NOT: d_local
)IR";
@ -1420,13 +1410,13 @@ TEST(Reductions, ReductionCacheAccessesInnerReduceAxis) {
Tensor c = Compute(
"scale",
{{L, "l2"}, {N, "n1"}, {M, "m1"}},
{L, N, M},
[&](const VarHandle& l, const VarHandle& n, const VarHandle& m) {
return b.load(l, n, m) * a.load(l, n, m);
});
Tensor d = Reduce("sum", {{L, "l1"}}, Sum(), c, {{N, "n1"}, {M, "m1"}});
Tensor d = Reduce("sum", {L}, Sum(), c, {N, M});
Tensor e = Compute("scale", {{L, "l"}}, [&](const VarHandle& l) {
Tensor e = Compute("scale", {L}, [&](const VarHandle& l) {
return b.load(0, 0, l) * d.load(l);
});
@ -1439,7 +1429,8 @@ TEST(Reductions, ReductionCacheAccessesInnerReduceAxis) {
l.cacheAccesses(d.buf(), "d_local", d_loop);
l.prepareForCodegen();
StmtPtr result = IRSimplifier::simplify(l.root_stmt());
StmtPtr result =
LoopNest::sanitizeNames(IRSimplifier::simplify(l.root_stmt()));
SimpleIREvaluator cg_after(result, {a, b, e});
std::ostringstream oss;
@ -1447,13 +1438,13 @@ TEST(Reductions, ReductionCacheAccessesInnerReduceAxis) {
const std::string& expected_ir =
R"IR(
#CHECK: Allocate(d_local); // dtype=float, dims=[1]
#CHECK: sum[l1] = 0
#CHECK: for (int n1
#CHECK: sum[i_1] = 0
#CHECK: for (int
#CHECK: d_local[0] = 0
#CHECK: for (int m1
#CHECK: for (int
#CHECK: d_local[0] = (d_local[0]) + (scale[
#CHECK: }
#CHECK: sum[l1] = (sum[l1]) + (d_local[0])
#CHECK: sum[i_1] = (sum[i_1]) + (d_local[0])
#CHECK: }
#CHECK: Free(d_local);
#CHECK-NOT: d_local
@ -1489,13 +1480,13 @@ TEST(Reductions, ReductionCacheBodyAccess) {
Tensor c = Compute(
"scale",
{{24, "l2"}, {32, "n1"}, {12, "m1"}},
{24, 32, 12},
[&](const VarHandle& l, const VarHandle& n, const VarHandle& m) {
return b.load(l, n, m) * a.load(l, n, m);
});
Tensor d = Reduce("sum", {{24, "l1"}}, Sum(), c, {{32, "n1"}, {12, "m1"}});
Tensor d = Reduce("sum", {24}, Sum(), c, {32, 12});
Tensor e = Compute("scale", {{24, "l"}}, [&](const VarHandle& l) {
Tensor e = Compute("scale", {24}, [&](const VarHandle& l) {
return b.load(0, 0, l) * d.load(l);
});
@ -1505,7 +1496,8 @@ TEST(Reductions, ReductionCacheBodyAccess) {
l.cacheAccesses(c.buf(), "scale_local", d_loop);
l.prepareForCodegen();
StmtPtr result = IRSimplifier::simplify(l.root_stmt());
StmtPtr result =
LoopNest::sanitizeNames(IRSimplifier::simplify(l.root_stmt()));
SimpleIREvaluator cg(result, {a, b, e});
std::ostringstream oss;
@ -1513,11 +1505,11 @@ TEST(Reductions, ReductionCacheBodyAccess) {
const std::string& expected_ir =
R"IR(
#CHECK: Allocate(scale_local); // dtype=float, dims=[1, 32, 12]
#CHECK: for (int j = 0; j < 32; j++) {
#CHECK: for (int k = 0; k < 12; k++) {
#CHECK: scale_local[k + 12 * j] = scale[(k + 12 * j) + 384 * l1];
#CHECK: sum[l1] = (sum[l1]) + (scale_local[m1_1 + 12 * n1_1]);
#CHECK: scale_1[l] = (b[l]) * (sum[l]);
#CHECK: for (int j_1 = 0; j_1 < 32; j_1++) {
#CHECK: for (int k_1 = 0; k_1 < 12; k_1++) {
#CHECK: scale_local[k_1 + 12 * j_1] = scale[(k_1 + 12 * j_1) + 384 * i_1];
#CHECK: sum[i_1] = (sum[i_1]) + (scale_local[k_2 + 12 * j_2]);
#CHECK: scale_1[i_2] = (b[i_2]) * (sum[i_2]);
#CHECK: Free(scale_local);
)IR";
torch::jit::testing::FileCheck().run(expected_ir, oss.str());
@ -1529,13 +1521,13 @@ TEST(Reductions, ReductionCacheConsumerAccess) {
Tensor c = Compute(
"scale",
{{24, "l2"}, {32, "n1"}, {12, "m1"}},
{24, 32, 12},
[&](const VarHandle& l, const VarHandle& n, const VarHandle& m) {
return b.load(l, n, m) * a.load(l, n, m);
});
Tensor d = Reduce("sum", {{24, "l1"}}, Sum(), c, {{32, "n1"}, {12, "m1"}});
Tensor d = Reduce("sum", {24}, Sum(), c, {32, 12});
Tensor e = Compute("scale", {{24, "l"}}, [&](const VarHandle& l) {
Tensor e = Compute("scale", {24}, [&](const VarHandle& l) {
return b.load(0, 0, l) * d.load(l);
});
@ -1547,7 +1539,8 @@ TEST(Reductions, ReductionCacheConsumerAccess) {
l.cacheAccesses(d.buf(), "sum_local", e_loop);
l.prepareForCodegen();
StmtPtr result = IRSimplifier::simplify(l.root_stmt());
StmtPtr result =
LoopNest::sanitizeNames(IRSimplifier::simplify(l.root_stmt()));
SimpleIREvaluator cg(result, {a, b, e});
std::ostringstream oss;
@ -1555,10 +1548,10 @@ TEST(Reductions, ReductionCacheConsumerAccess) {
const std::string& expected_ir =
R"IR(
#CHECK: Alias(sum_local,scale);
#CHECK: sum[l1] = (sum[l1]) + (scale[
#CHECK: for (int i = 0; i < 4
#CHECK: sum_local[i] = sum[i + 4 * l_outer];
#CHECK: scale_1[l_inner + 4 * l_outer] = (b[l_inner + 4 * l_outer]) * (sum_local[l_inner]);
#CHECK: sum[i_1] = (sum[i_1]) + (scale[
#CHECK: for (int j_2 = 0; j_2 < 4
#CHECK: sum_local[j_2] = sum[j_2 + 4 * i_2];
#CHECK: scale_1[j_3 + 4 * i_2] = (b[j_3 + 4 * i_2]) * (sum_local[j_3]);
)IR";
torch::jit::testing::FileCheck().run(expected_ir, oss.str());
}
@ -1569,13 +1562,13 @@ TEST(Reductions, ReductionSplitCacheConsumerAccess) {
Tensor c = Compute(
"scale",
{{24, "l2"}, {32, "n1"}, {12, "m1"}},
{24, 32, 12},
[&](const VarHandle& l, const VarHandle& n, const VarHandle& m) {
return b.load(l, n, m) * a.load(l, n, m);
});
Tensor d = Reduce("sum", {{24, "l1"}}, Sum(), c, {{32, "n1"}, {12, "m1"}});
Tensor d = Reduce("sum", {24}, Sum(), c, {32, 12});
Tensor e = Compute("scale", {{24, "l"}}, [&](const VarHandle& l) {
Tensor e = Compute("scale", {24}, [&](const VarHandle& l) {
return b.load(0, 0, l) * d.load(l);
});
@ -1593,7 +1586,8 @@ TEST(Reductions, ReductionSplitCacheConsumerAccess) {
l.cacheAccesses(d.buf(), "sum_local", inner);
l.prepareForCodegen();
StmtPtr result = IRSimplifier::simplify(l.root_stmt());
StmtPtr result =
LoopNest::sanitizeNames(IRSimplifier::simplify(l.root_stmt()));
SimpleIREvaluator cg(result, {a, b, e});
// reduction changes but cache does not.
@ -1602,10 +1596,12 @@ TEST(Reductions, ReductionSplitCacheConsumerAccess) {
const std::string& expected_ir =
R"IR(
#CHECK: Alias(sum_local,scale);
#CHECK: sum[l1_inner + 4 * l1_outer] = (sum[l1_inner + 4 * l1_outer]) + (scale[((m1_1 + 12 * n1_1) + 1536 * l1_outer) + 384 * l1_inner]);
#CHECK: for (int i = 0; i < 4
#CHECK: sum_local[i] = sum[i + 4 * l_outer];
#CHECK: scale_1[l_inner + 4 * l_outer] = (b[l_inner + 4 * l_outer]) * (sum_local[l_inner]);
#CHECK: sum[j_1 + 4 * i_1] = (sum[j_1 + 4 * i_1]) + (scale[((l + 12 * k_1) + 1536 * i_1) + 384 * j_1]);
#CHECK: for (int i_2 = 0; i_2 < 6
#CHECK: for (int j_2 = 0; j_2 < 4
#CHECK: sum_local[j_2] = sum[j_2 + 4 * i_2];
#CHECK: for (int j_3 = 0; j_3 < 4
#CHECK: scale_1[j_3 + 4 * i_2] = (b[j_3 + 4 * i_2]) * (sum_local[j_3]);
)IR";
torch::jit::testing::FileCheck().run(expected_ir, oss.str());
}
@ -1616,13 +1612,13 @@ TEST(Reductions, ReductionReorderCacheConsumerAccess) {
Tensor c = Compute(
"scale",
{{24, "l2"}, {32, "n1"}, {12, "m1"}},
{24, 32, 12},
[&](const VarHandle& l, const VarHandle& n, const VarHandle& m) {
return b.load(l, n, m) * a.load(l, n, m);
});
Tensor d = Reduce("sum", {{24, "l1"}}, Sum(), c, {{32, "n1"}, {12, "m1"}});
Tensor d = Reduce("sum", {24}, Sum(), c, {32, 12});
Tensor e = Compute("scale", {{24, "l"}}, [&](const VarHandle& l) {
Tensor e = Compute("scale", {24}, [&](const VarHandle& l) {
return b.load(0, 0, l) * d.load(l);
});
@ -1641,7 +1637,8 @@ TEST(Reductions, ReductionReorderCacheConsumerAccess) {
l.cacheAccesses(d.buf(), "sum_local", inner);
l.prepareForCodegen();
StmtPtr result = IRSimplifier::simplify(l.root_stmt());
StmtPtr result =
LoopNest::sanitizeNames(IRSimplifier::simplify(l.root_stmt()));
SimpleIREvaluator cg(result, {a, b, e});
// neither reduction body nor cache changes.
@ -1649,10 +1646,12 @@ TEST(Reductions, ReductionReorderCacheConsumerAccess) {
oss << *cg.stmt();
const std::string& expected_ir =
R"IR(
#CHECK: sum[l1] = (sum[l1]) + (scale[(m1_1 + 12 * n1_1) + 384 * l1]);
#CHECK: for (int i = 0; i < 4
#CHECK: sum_local[i] = sum[i + 4 * l_outer];
#CHECK: scale_1[l_inner + 4 * l_outer] = (b[l_inner + 4 * l_outer]) * (sum_local[l_inner]);
#CHECK: sum[j_1] = (sum[j_1]) + (scale[(k_1 + 12 * i_2) + 384 * j_1]);
#CHECK: for (int i_3 = 0; i_3 < 6;
#CHECK: for (int j_2 = 0; j_2 < 4;
#CHECK: sum_local[j_2] = sum[j_2 + 4 * i_3];
#CHECK: for (int j_3 = 0; j_3 < 4;
#CHECK: scale_1[j_3 + 4 * i_3] = (b[j_3 + 4 * i_3]) * (sum_local[j_3]);
)IR";
torch::jit::testing::FileCheck().run(expected_ir, oss.str());
}
@ -1673,7 +1672,7 @@ TEST(Reductions, ReductionRfactorCacheTempOuter) {
std::vector<float> out(1, -1.f);
Tensor c = Reduce("sum", {}, Sum(), b, {{m, "a"}, {n, "b"}, {k, "c"}});
Tensor c = Reduce("sum", {}, Sum(), b, {m, n, k});
LoopNest loop({c});
std::vector<ForPtr> loops = loop.getLoopStmtsFor(c);
@ -1693,7 +1692,7 @@ TEST(Reductions, ReductionRfactorCacheTempOuter) {
LoopNest::cacheAccesses(rfac_buf, "tmp", all_loops[1][1]);
loop.simplify();
loop.prepareForCodegen();
StmtPtr s = loop.root_stmt();
StmtPtr s = LoopNest::sanitizeNames(loop.root_stmt());
SimpleIREvaluator cg(s, {b, c, m, n, k});
std::ostringstream oss;
@ -1702,17 +1701,17 @@ TEST(Reductions, ReductionRfactorCacheTempOuter) {
R"IR(
#CHECK: Allocate(sum_rfac); // dtype=float, dims=[n]
#CHECK: Allocate(tmp); // dtype=float, dims=[n]
#CHECK: for (int a = 0; a < m
#CHECK: for (int i = 0; i < n
#CHECK: tmp[i] = 0
#CHECK: for (int i_1 = 0; i_1 < m
#CHECK: for (int j = 0; j < n
#CHECK: tmp[j] = 0
#CHECK: }
#CHECK: for (int b = 0; b < n
#CHECK: for (int c
#CHECK: tmp[b] = (tmp[b]) + (B[
#CHECK: for (int j_1 = 0; j_1 < n
#CHECK: for (int k
#CHECK: tmp[j_1] = (tmp[j_1]) + (B[
#CHECK: }
#CHECK: }
#CHECK: for (int i = 0; i < n
#CHECK: sum_rfac[i] = (sum_rfac[i]) + (tmp[i]);
#CHECK: for (int j_2 = 0; j_2 < n
#CHECK: sum_rfac[j_2] = (sum_rfac[j_2]) + (tmp[j_2]);
#CHECK: }
#CHECK: Free(tmp);
#CHECK-NOT: tmp
@ -1739,7 +1738,7 @@ TEST(Reductions, ReductionRfactorCacheTempInner) {
std::vector<float> out(1, -1.f);
Tensor c = Reduce("sum", {}, Sum(), b, {{m, "a"}, {n, "b"}, {k, "c"}});
Tensor c = Reduce("sum", {}, Sum(), b, {m, n, k});
LoopNest loop({c});
std::vector<ForPtr> loops = loop.getLoopStmtsFor(c);
auto c_body = loop.getAllWritesToBuf(c.buf())[1];
@ -1759,7 +1758,7 @@ TEST(Reductions, ReductionRfactorCacheTempInner) {
LoopNest::cacheAccesses(rfac_buf, "tmp", all_loops[1][2]);
loop.prepareForCodegen();
loop.simplify();
StmtPtr s = loop.root_stmt();
StmtPtr s = LoopNest::sanitizeNames(loop.root_stmt());
SimpleIREvaluator cg(s, {b, c, m, n, k});
std::ostringstream oss;
@ -1768,13 +1767,13 @@ TEST(Reductions, ReductionRfactorCacheTempInner) {
R"IR(
#CHECK: Allocate(sum_rfac); // dtype=float, dims=[n]
#CHECK: Allocate(tmp); // dtype=float, dims=[1]
#CHECK: for (int a = 0; a < m
#CHECK: for (int b = 0; b < n
#CHECK: for (int i_1 = 0; i_1 < m
#CHECK: for (int j = 0; j < n
#CHECK: tmp[0] = 0
#CHECK: for (int c
#CHECK: for (int k
#CHECK: tmp[0] = (tmp[0]) + (B[
#CHECK: }
#CHECK: sum_rfac[b] = (sum_rfac[b]) + (tmp[0]);
#CHECK: sum_rfac[j] = (sum_rfac[j]) + (tmp[0]);
#CHECK: Free(tmp);
#CHECK-NOT: tmp
)IR";
@ -1796,7 +1795,7 @@ TEST(Reductions, ReductionVectorize) {
BufHandle in("in", {8, 8}, kFloat);
Tensor tensor = Reduce("sum", {{8, "m"}}, Sum(), in, {{8, "n"}});
Tensor tensor = Reduce("sum", {8}, Sum(), in, {8});
LoopNest l_before({tensor});
LoopNest l(l_before);
l_before.prepareForCodegen();
@ -1806,15 +1805,15 @@ TEST(Reductions, ReductionVectorize) {
ASSERT_TRUE(LoopNest::vectorize(l.getLoopStmtsFor(tensor)[0]));
StmtPtr s = l.root_stmt();
s = IRSimplifier::simplify(s);
s = LoopNest::sanitizeNames(IRSimplifier::simplify(s));
std::ostringstream oss;
oss << *s;
const std::string& expected_ir =
R"IR(
#CHECK: sum[Ramp(0, 1, 8)] = Broadcast(0.f, 8);
#CHECK: for (int n = 0; n < 8; n++) {
#CHECK: sum[Ramp(0, 1, 8)] = ReduceOp((sum[Ramp(0, 1, 8)]) + (in[Ramp(n, 8, 8)]), reduce_args={n});
#CHECK: for (int i = 0; i < 8; i++) {
#CHECK: sum[Ramp(0, 1, 8)] = ReduceOp((sum[Ramp(0, 1, 8)]) + (in[Ramp(i, 8, 8)]), reduce_args={i});
#CHECK: }
)IR";
torch::jit::testing::FileCheck().run(expected_ir, oss.str());
@ -1832,7 +1831,7 @@ TEST(Reductions, ReductionVectorize) {
TEST(Reductions, ReductionVectorizeInner) {
BufHandle in("in", {8, 8}, kFloat);
Tensor tensor = Reduce("sum", {{8, "m"}}, Sum(), in, {{8, "n"}});
Tensor tensor = Reduce("sum", {8}, Sum(), in, {8});
LoopNest l({tensor});
ASSERT_FALSE(LoopNest::vectorize(l.getLoopStmtsFor(tensor)[1]));
@ -1850,7 +1849,7 @@ TEST(Reductions, ReductionVectorizeRfactor) {
BufHandle in("in", {8, 8}, kFloat);
Tensor tensor = Reduce("sum", {}, Sum(), in, {{8, "m"}, {8, "n"}});
Tensor tensor = Reduce("sum", {}, Sum(), in, {8, 8});
LoopNest l_before({tensor});
LoopNest l(l_before);
@ -1875,21 +1874,21 @@ TEST(Reductions, ReductionVectorizeRfactor) {
ASSERT_TRUE(LoopNest::vectorize(rfac_loops[1][0]));
l.simplify();
StmtPtr s = l.root_stmt();
StmtPtr s = LoopNest::sanitizeNames(l.root_stmt());
std::ostringstream oss;
oss << *s;
const std::string& expected_ir =
R"IR(
#CHECK: sum = 0.f;
#CHECK: for (int n = 0; n < 8; n++) {
#CHECK: sum_rfac[n] = 0.f;
#CHECK: for (int i = 0; i < 8; i++) {
#CHECK: sum_rfac[i] = 0.f;
#CHECK: }
#CHECK: for (int m = 0; m < 8; m++) {
#CHECK: sum_rfac[Ramp(0, 1, 8)] = ReduceOp((sum_rfac[Ramp(0, 1, 8)]) + (in[Ramp(8 * m, 1, 8)]), reduce_args={m});
#CHECK: for (int i_1 = 0; i_1 < 8; i_1++) {
#CHECK: sum_rfac[Ramp(0, 1, 8)] = ReduceOp((sum_rfac[Ramp(0, 1, 8)]) + (in[Ramp(8 * i_1, 1, 8)]), reduce_args={i_1});
#CHECK: }
#CHECK: for (int n = 0; n < 8; n++) {
#CHECK: sum = ReduceOp((sum) + (sum_rfac[n]), reduce_args={n});
#CHECK: for (int i_2 = 0; i_2 < 8; i_2++) {
#CHECK: sum = ReduceOp((sum) + (sum_rfac[i_2]), reduce_args={i_2});
#CHECK: }
)IR";
torch::jit::testing::FileCheck().run(expected_ir, oss.str());
@ -1910,22 +1909,22 @@ TEST(Reductions, InitFunction) {
BufHandle B("B", {N}, kFloat);
Tensor C = Reduce(
"C",
{{N, "n"}},
{N},
Sum(),
[&](const std::vector<VarHandle>& v) { return B.load(v[0]); },
[&](const std::vector<VarHandle>& v) { return A.load(v[1], v[0]); },
{{M, "m"}});
{M});
LoopNest nest({C});
nest.prepareForCodegen();
StmtPtr s = IRSimplifier::simplify(nest.root_stmt());
StmtPtr s = LoopNest::sanitizeNames(IRSimplifier::simplify(nest.root_stmt()));
std::ostringstream oss;
oss << *s << "\n";
const std::string& expected_ir =
R"IR(
#CHECK: for (int n = 0; n < 16; n++) {
#CHECK: C[n] = B[n];
#CHECK: for (int m = 0; m < 32; m++) {
#CHECK: C[n] = (C[n]) + (A[n + 16 * m]);
#CHECK: for (int i = 0; i < 16; i++) {
#CHECK: C[i] = B[i];
#CHECK: for (int j = 0; j < 32; j++) {
#CHECK: C[i] = (C[i]) + (A[i + 16 * j]);
#CHECK: }
#CHECK: }
)IR";


@ -3858,26 +3858,25 @@ TEST(Simplify, SimplifyForCleansUp) {
BufHandle a("a", {1, 12, 1}, kFloat);
VarHandle x("x", kInt);
Tensor b = Compute(
// NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
"x",
{{1, "i"}, {12, "m"}, {1, "n"}},
{1, 12, 1},
[](const VarHandle& i, const VarHandle& m, const VarHandle& n) {
return i + m + n;
});
LoopNest l({b});
l.prepareForCodegen();
StmtPtr body = l.root_stmt();
StmtPtr body = LoopNest::sanitizeNames(l.root_stmt());
StmtPtr simplified = IRSimplifier::simplify(body);
BlockPtr block = to<Block>(simplified);
IS_NODE_WITH_NAME(For, block->front(), for_);
// for is over "m".
IS_VAR_WITH_NAME(for_->var(), "m");
IS_VAR_WITH_NAME(for_->var(), "j");
// x[m] = m;
IS_NODE_WITH_NAME(Store, for_->body()->front(), store);
IS_VAR_WITH_NAME(store->flat_index(), "m");
IS_VAR_WITH_NAME(store->value(), "m");
IS_VAR_WITH_NAME(store->flat_index(), "j");
IS_VAR_WITH_NAME(store->value(), "j");
}
}


@ -186,10 +186,10 @@ int main(int argc, char* argv[]) {
// structure is simply a pair of a buffer that was created to represent the
// result of the computation (BufPtr) and a statement representing the
// computation itself (StmtPtr).
Tensor C = Compute(
"C",
{{64, "i"}, {32, "j"}},
[&](const VarHandle& i, const VarHandle& j) { return i * j; });
Tensor C =
Compute("C", {64, 32}, [&](const VarHandle& i, const VarHandle& j) {
return i * j;
});
std::cout << "Stmt produced by 'Compute' API: " << std::endl
<< *C.stmt() << std::endl;
// Prints:
@ -209,7 +209,7 @@ int main(int argc, char* argv[]) {
{},
Sum(),
[&](const VarHandle& i, const VarHandle& j) { return C.load(i, j); },
{{64, "i"}, {32, "j"}});
{64, 32});
std::cout << "Stmt produced by 'Reduce' API: " << std::endl
<< *D.stmt() << std::endl;
}
@ -223,15 +223,13 @@ int main(int argc, char* argv[]) {
// Let's look at a couple of transformations that are used in NNC. We will
// begin with constructing a Block statement like we did before.
Tensor C = Compute(
"C",
{{64, "i"}, {32, "j"}},
[&](const VarHandle& i, const VarHandle& j) { return i * (j + 1); });
Tensor C =
Compute("C", {64, 32}, [&](const VarHandle& i, const VarHandle& j) {
return i * (j + 1);
});
BufHandle c_buf(C.buf());
Tensor D = Compute(
"D",
{{64, "i"}, {32, "j"}},
[&](const VarHandle& i, const VarHandle& j) {
Tensor D =
Compute("D", {64, 32}, [&](const VarHandle& i, const VarHandle& j) {
return c_buf.load(i, j) - i;
});
StmtPtr block = Block::make({C.stmt(), D.stmt()});
@ -353,10 +351,8 @@ int main(int argc, char* argv[]) {
// Let's start by constructing a simple computation for us to work with:
BufHandle A("A", {64, 32}, kInt);
BufHandle B("B", {64, 32}, kInt);
Tensor X = Compute(
"X",
{{64, "i"}, {32, "j"}},
[&](const VarHandle& i, const VarHandle& j) {
Tensor X =
Compute("X", {64, 32}, [&](const VarHandle& i, const VarHandle& j) {
return A.load(i, j) + B.load(i, j);
});


@ -349,19 +349,13 @@ graph(%x : Float(2, 2, strides=[2, 1], requires_grad=0, device=cpu)):
graph = torch._C.parse_ir(graph_str)
def my_custom_lowering(inputs, out_shape, out_type, device):
def get_dim_args(dims):
dim_args = []
for dim in dims:
dim_args.append(te.DimArg(dim, "i" + str(len(dim_args))))
return dim_args
def compute(idxs):
load = inputs[0].as_buf().load(idxs)
return te.ifThenElse(
te.ExprHandle.isnan(load), te.ExprHandle.float(0.0), load
)
return te.Compute2("custom_nan_to_num", get_dim_args(out_shape), compute)
return te.Compute2("custom_nan_to_num", out_shape, compute)
kernel = te.TensorExprKernel(graph, {"aten::nan_to_num": my_custom_lowering})
res1 = kernel.run((x,))


@ -1,33 +0,0 @@
#pragma once
#include <torch/csrc/jit/tensorexpr/expr.h>
namespace torch {
namespace jit {
namespace tensorexpr {
// A helper structure to store the arguments to specify dimensions. In the
// Compute arguments for dim_args, all of the following are supported. For
// example:
// dim_args: {1, 2, 3, 4}
// dim_args: {{1, "x"}, {2, "y"}, {3, "z"}}
// dim_args: {1, 2, {3, "x"}}
class DimArg {
public:
// Intentionally leave out explicit to allow implicit conversions.
DimArg(const ExprHandle& dim) : dim_(dim) {}
DimArg(const ExprHandle& dim, std::string name_hint)
: dim_(dim), name_hint_(std::move(name_hint)) {}
const ExprHandle& dim() const {
return dim_;
}
const std::string& name_hint() const {
return name_hint_;
}
private:
ExprHandle dim_;
std::string name_hint_;
};
} // namespace tensorexpr
} // namespace jit
} // namespace torch
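For reference, a minimal sketch (not part of this patch) of how a call site migrates once DimArg is gone: dimensions are passed as plain ExprHandles and no per-axis name hints are given. It assumes only the tensorexpr headers already used above.

#include <torch/csrc/jit/tensorexpr/tensor.h>

using namespace torch::jit::tensorexpr;

Tensor buildSumExample() {
  constexpr int M = 16, N = 32, K = 12;
  BufHandle b("b", {M, N, K}, kFloat);
  // Before this change: Reduce("sum", {{M, "m"}}, Sum(), b, {{N, "n"}, {K, "k"}});
  // After this change, dims are plain ExprHandles; loop-variable names are
  // generated internally (see create_index_vars / sanitizeNames below).
  Tensor c = Reduce("sum", {M}, Sum(), b, {N, K});
  return c;
}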


@ -402,7 +402,7 @@ void IRPrinter::visit(ReduceOpPtr v) {
if (!first) {
os() << ", ";
}
os() << d->name_hint();
os() << *d;
first = false;
}
os() << "})";


@ -1033,12 +1033,7 @@ Tensor TensorExprKernel::bindInput(const torch::jit::Value* input) {
// if the input isn't contiguous or is an output,
// write strided input into contiguous buffer that is
// then used in all further compute
std::vector<DimArg> inputTensorDims;
auto size_handles = sizesFromSymbolicShape(tt->symbolic_sizes());
for (size_t i = 0; i < size_handles.size(); i++) {
auto size = size_handles[i];
inputTensorDims.emplace_back(DimArg(size, "i" + c10::to_string(i)));
}
auto inputTensorStrides = getInputStrides(input, size_handles);
ExprHandle flat_size = 1;
for (size_t i = 0; i < size_handles.size(); ++i) {
@ -1057,7 +1052,7 @@ Tensor TensorExprKernel::bindInput(const torch::jit::Value* input) {
result = Compute(
"input" + c10::to_string(bufs_.size() + 1),
inputTensorDims,
size_handles,
[&](const std::vector<VarHandle>& axes) {
ExprHandle idx = 0;
for (size_t i = 0; i < axes.size(); i++) {
@ -1144,11 +1139,10 @@ Tensor TensorExprKernel::convertSymbolicOutputToCorrectStrides(
// for stride in strides_from_largest_to_smallest:
// cur_idx = absolute // stride
// absolute = absolute % stride
auto dims = c10::fmap<DimArg>(sizes);
std::vector<ExprPtr> default_strides = make_contiguous_strides(sizes);
auto zero = LongImm::make(0);
return Compute(
"output_1", dims, [&](const std::vector<VarHandle>& axes_input) {
"output_1", sizes, [&](const std::vector<VarHandle>& axes_input) {
std::vector<ExprHandle> axes(axes_input.begin(), axes_input.end());
auto absolute_position = ExprHandle(immLike(axes[0], 0));
for (size_t i = 0; i < axes.size(); ++i) {
@ -1191,7 +1185,6 @@ Tensor TensorExprKernel::convertSymbolicOutputToCorrectStrides(
tensorOutputStrideDesc_[v->offset()] ==
torch::jit::StrideInput::TENSOR_CONT_CHANNELS_LAST);
auto sizes = sizesFromSymbolicShape(tt->symbolic_sizes());
auto dims = c10::fmap<DimArg>(sizes);
auto strides = make_channels_last_strides(sizes);
// For a tensor with dimensions N C H W, channels last
// format will be N H W C,
@ -1243,7 +1236,7 @@ Tensor TensorExprKernel::convertStaticShapeOutputToCorrectStrides(
return Tensor(buf, nullptr);
}
auto dims = c10::fmap<DimArg>(sizesForValue(v));
auto dims = sizesForValue(v);
auto zero = LongImm::make(0);
std::vector<size_t> sorted_stride_indices = reverse_sort_indices(strides);
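The stride-walk comment above (cur_idx = absolute // stride; absolute = absolute % stride) is ordinary flat-index-to-coordinate conversion. A plain-C++ sketch of that arithmetic, independent of this patch:

#include <cstdint>
#include <vector>

// Convert a flat element position into per-dimension indices by walking the
// strides from largest to smallest (assumes strides are sorted descending).
std::vector<int64_t> flatToCoords(
    int64_t absolute,
    const std::vector<int64_t>& strides_largest_to_smallest) {
  std::vector<int64_t> coords;
  coords.reserve(strides_largest_to_smallest.size());
  for (int64_t stride : strides_largest_to_smallest) {
    coords.push_back(absolute / stride); // cur_idx = absolute // stride
    absolute = absolute % stride;        // remainder indexes the remaining dims
  }
  return coords;
}

// E.g. a contiguous 2x3x4 tensor has strides {12, 4, 1}:
// flatToCoords(17, {12, 4, 1}) returns {1, 1, 1}, since 17 = 1*12 + 1*4 + 1*1.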


@ -198,7 +198,6 @@ class TORCH_API TensorExprKernel {
void genInputDebugNames();
void runKernel(Stack& stack);
std::vector<DimArg> dimsFromSizes(const std::vector<ExprHandle>& sizes);
std::vector<ExprHandle> sizesForValue(const torch::jit::Value* v);
// These functions broadcast shape and also store a `hasBroadcast_` variable.


@ -37,11 +37,13 @@ namespace tensorexpr {
LoopNest::LoopNest(const LoopNest& other)
: root_stmt_(Stmt::clone(other.root_stmt_)),
output_bufs_(other.output_bufs_) {
GRAPH_DEBUG("Origin Stmt in LoopNest:\n", std::to_string(root_stmt_));
verify(root_stmt_);
}
LoopNest::LoopNest(StmtPtr stmt, std::unordered_set<BufPtr> output_bufs)
: root_stmt_(stmt), output_bufs_(std::move(output_bufs)) {
GRAPH_DEBUG("Origin Stmt in LoopNest:\n", std::to_string(root_stmt_));
verify(root_stmt_);
}
@ -50,12 +52,14 @@ LoopNest::LoopNest(
const std::vector<Tensor>& output_tensors,
const std::vector<Tensor>& tensors_to_compute) {
initialize(output_tensors, tensors_to_compute);
GRAPH_DEBUG("Origin Stmt in LoopNest:\n", std::to_string(root_stmt_));
verify(root_stmt_);
}
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
LoopNest::LoopNest(const std::vector<Tensor>& output_tensors) {
initialize(output_tensors, output_tensors);
GRAPH_DEBUG("Origin Stmt in LoopNest:\n", std::to_string(root_stmt_));
verify(root_stmt_);
}


@ -953,9 +953,7 @@ int nnc_lowerings_lazy_registration() {
auto const& shape =
broadcastShapes(valueShape(inputs[0]), valueShape(inputs[1]));
return Compute(
"aten_remainder",
c10::fmap<DimArg>(shape),
[&](const std::vector<VarHandle>& axes) {
"aten_remainder", shape, [&](const std::vector<VarHandle>& axes) {
std::vector<ExprHandle> indices(axes.begin(), axes.end());
std::vector<ExprHandle> exprInputs = {
tensorOrConstant(inputs[0], indices),
@ -1454,7 +1452,7 @@ int nnc_lowerings_lazy_registration() {
// at::Device device) {
// return Compute(
// "aten_slice",
// c10::fmap<DimArg>(outputShape),
// outputShape,
// [&](const std::vector<VarHandle>& axes) {
// int64_t dim =
// at::maybe_wrap_dim(c10::get<int64_t>(inputs[1]),
@ -1475,7 +1473,7 @@ int nnc_lowerings_lazy_registration() {
at::Device device) {
return Compute(
"aten_unsqueeze",
c10::fmap<DimArg>(outputShape),
outputShape,
[&](const std::vector<VarHandle>& axes) {
int64_t dim = c10::get<int64_t>(inputs[1]);
if (dim < 0) {
@ -1525,7 +1523,7 @@ int nnc_lowerings_lazy_registration() {
if (A.ndim() == 0) {
auto tensor = Compute(
"aten_permute",
c10::fmap<DimArg>(outputShape),
outputShape,
[&](const std::vector<VarHandle>& axes) {
std::vector<ExprHandle> empty_indices;
return A.load(empty_indices);
@ -1539,7 +1537,7 @@ int nnc_lowerings_lazy_registration() {
auto permute_dims = c10::get<IntList>(inputs[1]);
auto tensor = Compute(
"aten_permute",
c10::fmap<DimArg>(outputShape),
outputShape,
[&](const std::vector<VarHandle>& axes) {
std::vector<VarHandle> new_axes;
new_axes.resize(axes.size());


@ -49,7 +49,7 @@ Tensor conv2d_depthwise_static(
Tensor conv = Reduce(
"conv2d_depthwise",
{{N, "n"}, {K, "k"}, {OH, "oh"}, {OW, "ow"}},
{N, K, OH, OW},
Sum(),
[&](const std::vector<VarHandle>& v) { return init_func(v); },
[&](const std::vector<VarHandle>& v) {
@ -70,7 +70,7 @@ Tensor conv2d_depthwise_static(
input.load(n, k, oh * stride - pad + r, ow * stride - pad + s));
return in * weight.load(k, c, r, s);
},
{{C / groups, "c"}, {R, "r"}, {S, "s"}});
{C / groups, R, S});
LoopNest nest({conv});
@ -120,7 +120,7 @@ Tensor conv2d_depthwise_dynamic(
return Reduce(
"conv2d_depthwise",
{{N, "n"}, {K, "k"}, {OH, "oh"}, {OW, "ow"}},
{N, K, OH, OW},
Sum(),
[&](const std::vector<VarHandle>& v) { return init_func(v); },
[&](const std::vector<VarHandle>& v) {
@ -141,7 +141,7 @@ Tensor conv2d_depthwise_dynamic(
input.load(n, k, oh * stride - pad + r, ow * stride - pad + s));
return in * weight.load(k, c, r, s);
},
{{C / groups, "c"}, {R, "r"}, {S, "s"}});
{C / groups, R, S});
}
} // namespace


@ -38,12 +38,12 @@ Tensor computeMatmul(
if (total_size && total_size->value() < 1000) {
return Reduce(
"nnc_matmul",
{{size_a[0], "M"}, {size_b[1], "N"}},
{size_a[0], size_b[1]},
Sum(),
[&](const ExprHandle& m, const ExprHandle& n, const ExprHandle& k) {
return Load::make(a, {m, k}) * Load::make(b, {k, n});
},
{{size_a[1], "K"}});
{size_a[1]});
} else {
return Tensor(
ResultBuf.node(),


@ -324,7 +324,7 @@ Tensor computeChunk(
at::Device device) {
return Compute(
"prim_constantchunk",
c10::fmap<DimArg>(outputShape),
outputShape,
[inputs](const std::vector<VarHandle>& axes) {
const auto& b = c10::get<BufHandle>(inputs[0]);
int64_t chunkIdx = c10::get<int64_t>(inputs[1]);
@ -359,9 +359,7 @@ Tensor computeTranspose(
// Trivial case of 0-dim and 1-dim tensors: transpose is just a copy
if (A.ndim() <= 1) {
return Compute(
"aten_transpose",
c10::fmap<DimArg>(outputShape),
[&](std::vector<VarHandle> axes) {
"aten_transpose", outputShape, [&](std::vector<VarHandle> axes) {
TORCH_INTERNAL_ASSERT(
axes.size() <= 1,
buildErrorMessage("Invalid axes size in transpose"));
@ -372,9 +370,7 @@ Tensor computeTranspose(
auto start_dim = at::maybe_wrap_dim(c10::get<int64_t>(inputs[1]), A.ndim());
auto to_dim = at::maybe_wrap_dim(c10::get<int64_t>(inputs[2]), A.ndim());
return Compute(
"aten_transpose",
c10::fmap<DimArg>(outputShape),
[&](std::vector<VarHandle> axes) {
"aten_transpose", outputShape, [&](std::vector<VarHandle> axes) {
std::swap(axes[start_dim], axes[to_dim]);
return A.load(axes);
});
@ -387,9 +383,7 @@ Tensor computeExpand(
at::Device device) {
auto A = c10::get<BufHandle>(inputs[0]);
return Compute(
"aten_expand",
c10::fmap<DimArg>(outputShape),
[&](const std::vector<VarHandle>& axes) {
"aten_expand", outputShape, [&](const std::vector<VarHandle>& axes) {
std::vector<ExprHandle> indices(axes.begin(), axes.end());
return broadcast(A, indices);
});
@ -403,17 +397,13 @@ Tensor computeReshape(
auto A = c10::get<BufHandle>(inputs[0]);
if (A.ndim() == 0) {
return Compute(
"aten_view",
c10::fmap<DimArg>(outputShape),
[&](const std::vector<VarHandle>& axes) {
"aten_view", outputShape, [&](const std::vector<VarHandle>& axes) {
std::vector<ExprHandle> empty_indices;
return A.load(empty_indices);
});
}
return Compute(
"aten_reshape",
c10::fmap<DimArg>(outputShape),
[&](const std::vector<VarHandle>& axes) {
"aten_reshape", outputShape, [&](const std::vector<VarHandle>& axes) {
std::vector<VarHandle> new_axes;
assert(outputShape.size() == axes.size());
/*
@ -608,9 +598,7 @@ Tensor computeCat(
ScalarType highType = catInfo.first;
std::vector<BufHandle> nonEmptyInputs = catInfo.second;
return Compute(
"aten_cat",
c10::fmap<DimArg>(outputShape),
[&](const std::vector<VarHandle>& axes) {
"aten_cat", outputShape, [&](const std::vector<VarHandle>& axes) {
if (nonEmptyInputs.size() == 0) {
return ExprHandle(0);
}


@ -22,9 +22,7 @@ Tensor computeBatchNorm(
}
return Compute(
"aten_batch_norm",
c10::fmap<DimArg>(outputShape),
[&](const std::vector<VarHandle>& axes) {
"aten_batch_norm", outputShape, [&](const std::vector<VarHandle>& axes) {
TORCH_INTERNAL_ASSERT(axes.size() >= 2);
// axes: N, C, H, W
std::vector<ExprHandle> indices(axes.begin(), axes.end());


@ -10,16 +10,15 @@ using namespace torch::jit::tensorexpr;
Tensor computeSign(
const std::vector<ArgValue>& inputValues,
const std::vector<ExprHandle>& outputShape) {
return Compute(
"aten_sign", c10::fmap<DimArg>(outputShape), [&](ParameterList& axes) {
std::vector<ExprHandle> indices(axes.begin(), axes.end());
std::vector<ExprHandle> inputs = {
tensorOrConstant(inputValues[0], indices)};
auto inp = inputs[0];
auto zero = ExprHandle(immLike(inp, 0.0f));
auto res = (zero < inp) - (inp < zero);
return promoteToDtype(res, inp.dtype().scalar_type());
});
return Compute("aten_sign", outputShape, [&](ParameterList& axes) {
std::vector<ExprHandle> indices(axes.begin(), axes.end());
std::vector<ExprHandle> inputs = {
tensorOrConstant(inputValues[0], indices)};
auto inp = inputs[0];
auto zero = ExprHandle(immLike(inp, 0.0f));
auto res = (zero < inp) - (inp < zero);
return promoteToDtype(res, inp.dtype().scalar_type());
});
}
Tensor computeOneOperand(
@ -31,7 +30,7 @@ Tensor computeOneOperand(
const int checkParamTypes) {
return Compute(
name,
c10::fmap<DimArg>(outputShape),
outputShape,
[inputValues, outputType, innerExpr, checkParamTypes](
const std::vector<VarHandle>& axes) {
std::vector<ExprHandle> indices(axes.begin(), axes.end());
@ -52,7 +51,7 @@ Tensor computeTwoOperand(
innerExpr) {
return Compute(
name,
c10::fmap<DimArg>(outputShape),
outputShape,
[inputValues, outputType, innerExpr](const std::vector<VarHandle>& axes) {
std::vector<ExprHandle> indices(axes.begin(), axes.end());
std::vector<ExprHandle> inputs = {
@ -75,7 +74,7 @@ Tensor computeTwoOperandWithAlpha(
innerExpr) {
return Compute(
name,
c10::fmap<DimArg>(outputShape),
outputShape,
[inputValues, outputType, innerExpr](const std::vector<VarHandle>& axes) {
std::vector<ExprHandle> indices(axes.begin(), axes.end());
std::vector<ExprHandle> inputs = {
@ -100,7 +99,7 @@ Tensor computeConditionWithTwoOperand(
innerExpr) {
return Compute(
name,
c10::fmap<DimArg>(outputShape),
outputShape,
[inputValues, outputType, innerExpr](const std::vector<VarHandle>& axes) {
std::vector<ExprHandle> indices(axes.begin(), axes.end());
std::vector<ExprHandle> inputs = {
@ -128,7 +127,7 @@ Tensor computeThreeOperand(
bool promote_inputs) {
return Compute(
name,
c10::fmap<DimArg>(outputShape),
outputShape,
[inputValues, outputType, innerExpr, promote_inputs](
const std::vector<VarHandle>& axes) {
std::vector<ExprHandle> indices(axes.begin(), axes.end());
@ -157,7 +156,7 @@ Tensor computeFourOperand(
const ExprHandle&)>& innerExpr) {
return Compute(
name,
c10::fmap<DimArg>(outputShape),
outputShape,
[inputValues, outputType, innerExpr](const std::vector<VarHandle>& axes) {
std::vector<ExprHandle> indices(axes.begin(), axes.end());
std::vector<ExprHandle> inputs = {


@ -727,9 +727,7 @@ Tensor computeUpsampleNearest2d(
auto input_height = ExprHandle(A.dim(2));
auto input_width = ExprHandle(A.dim(3));
std::vector<ExprHandle> dims;
std::vector<VarHandle> args;
unpack_dim_args(c10::fmap<DimArg>(outputShape), &dims, &args);
std::vector<VarHandle> args = create_index_vars(outputShape);
// Handle separately when scale is specified? as in 'scalar_t
// compute_scales_value' in UpSample.h
auto scale_h =


@ -53,12 +53,12 @@ Tensor computeSum(
std::iota(axes.begin(), axes.end(), 0);
}
// Axes go into reduction dimensions.
std::vector<DimArg> reductionDims;
std::vector<ExprHandle> reductionDims;
reductionDims.reserve(rank);
for (size_t axis : axes) {
reductionDims.emplace_back(sizes[axis]);
}
std::vector<DimArg> outputDims;
std::vector<ExprHandle> outputDims;
// Output dimensions are the complement of axes. When keepdim is set, a
// one-sized dimension is inserted for each axis.
for (size_t dim = 0; dim < rank; ++dim) {


@ -40,7 +40,6 @@ Tensor computeSoftmax(
// - Final loop computes the log_softmax for every element in v.
TORCH_INTERNAL_ASSERT(inputs.size() == 3);
auto output_dims = c10::fmap<DimArg>(outputShape);
// We do not handle None for dims (input 1) because that is supposed to
// be deprecated.
@ -48,10 +47,10 @@ Tensor computeSoftmax(
int64_t rank = valueShape(inputs[0]).size();
size_t softmax_dim =
normalizeAndCheckIndex(c10::get<int64_t>(inputs[1]), rank);
std::vector<DimArg> non_softmax_dims;
for (size_t i = 0; i < output_dims.size(); ++i) {
std::vector<ExprHandle> non_softmax_dims;
for (size_t i = 0; i < outputShape.size(); ++i) {
if (i != softmax_dim) {
non_softmax_dims.push_back(output_dims[i]);
non_softmax_dims.push_back(outputShape[i]);
}
}
@ -108,9 +107,9 @@ Tensor computeSoftmax(
return tensorOrConstant(
inputs[0], move_softmax_dim_index_to_pos(indices));
},
{output_dims[softmax_dim]});
{outputShape[softmax_dim]});
auto e =
Compute("aten_softmax_exp", output_dims, [&](ParameterList& indices) {
Compute("aten_softmax_exp", outputShape, [&](ParameterList& indices) {
auto inp = tensorOrConstant(
inputs[0], convert_indices_to_expr_handle(indices));
return exp(inp - max.load(remove_softmax_dim_index(indices)));
@ -122,10 +121,10 @@ Tensor computeSoftmax(
[&](ParameterList& indices) {
return e.load(move_softmax_dim_index_to_pos(indices));
},
{output_dims[softmax_dim]});
{outputShape[softmax_dim]});
if (!log_softmax) {
auto result =
Compute("aten_softmax", output_dims, [&](ParameterList& indices) {
Compute("aten_softmax", outputShape, [&](ParameterList& indices) {
return e.load(indices) / sum.load(remove_softmax_dim_index(indices));
});
return Tensor(
@ -139,7 +138,7 @@ Tensor computeSoftmax(
return log(sum.load(indices));
});
auto result =
Compute("aten_log_softmax", output_dims, [&](ParameterList& indices) {
Compute("aten_log_softmax", outputShape, [&](ParameterList& indices) {
auto inp = tensorOrConstant(
inputs[0], convert_indices_to_expr_handle(indices));
auto non_softmax_indices = remove_softmax_dim_index(indices);


@ -1,6 +1,5 @@
#pragma once
#include <torch/csrc/jit/tensorexpr/dim_arg.h>
#include <torch/csrc/jit/tensorexpr/expr.h>
#include <torch/csrc/jit/tensorexpr/ir.h>
#include <torch/csrc/jit/tensorexpr/ir_printer.h>


@ -2,7 +2,6 @@
#include <c10/util/Logging.h>
#include <c10/util/irange.h>
#include <torch/csrc/jit/tensorexpr/dim_arg.h>
#include <torch/csrc/jit/tensorexpr/reduction.h>
namespace torch {
@ -51,11 +50,9 @@ StmtPtr Tensor::constructStmt(
Tensor Compute(
const std::string& name,
const std::vector<DimArg>& dim_args,
const std::vector<ExprHandle>& dims,
const std::function<ExprHandle(const std::vector<VarHandle>&)>& body_func) {
std::vector<ExprHandle> dims;
std::vector<VarHandle> args;
unpack_dim_args(dim_args, &dims, &args);
std::vector<VarHandle> args = create_index_vars(dims);
ExprHandle body = body_func(args);
BufHandle buf = Buf::make(name, dims, body.dtype());
return Tensor(buf, args, body);
@ -63,15 +60,13 @@ Tensor Compute(
Tensor Compute(
const std::string& name,
const std::vector<DimArg>& dim_args,
const std::vector<ExprHandle>& dims,
const std::function<ExprHandle(const VarHandle&)>& body_func) {
if (dim_args.size() != 1) {
if (dims.size() != 1) {
throw malformed_input("mismatch between body and arg size (1)");
}
std::vector<ExprHandle> dims;
std::vector<VarHandle> args;
unpack_dim_args(dim_args, &dims, &args);
std::vector<VarHandle> args = create_index_vars(dims);
ExprHandle body = body_func(args[0]);
BufHandle buf = Buf::make(name, dims, body.dtype());
return Tensor(buf, args, body);
@ -79,15 +74,13 @@ Tensor Compute(
Tensor Compute(
const std::string& name,
const std::vector<DimArg>& dim_args,
const std::vector<ExprHandle>& dims,
const std::function<ExprHandle(const VarHandle&, const VarHandle&)>&
body_func) {
if (dim_args.size() != 2) {
if (dims.size() != 2) {
throw malformed_input("mismatch between body and arg size (2)");
}
std::vector<ExprHandle> dims;
std::vector<VarHandle> args;
unpack_dim_args(dim_args, &dims, &args);
std::vector<VarHandle> args = create_index_vars(dims);
ExprHandle body = body_func(args[0], args[1]);
BufHandle buf = Buf::make(name, dims, body.dtype());
return Tensor(buf, args, body);
@ -95,16 +88,14 @@ Tensor Compute(
Tensor Compute(
const std::string& name,
const std::vector<DimArg>& dim_args,
const std::vector<ExprHandle>& dims,
const std::function<
ExprHandle(const VarHandle&, const VarHandle&, const VarHandle&)>&
body_func) {
if (dim_args.size() != 3) {
if (dims.size() != 3) {
throw malformed_input("mismatch between body and arg size (3)");
}
std::vector<ExprHandle> dims;
std::vector<VarHandle> args;
unpack_dim_args(dim_args, &dims, &args);
std::vector<VarHandle> args = create_index_vars(dims);
ExprHandle body = body_func(args[0], args[1], args[2]);
BufHandle buf = Buf::make(name, dims, body.dtype());
return Tensor(buf, args, body);
@ -112,18 +103,16 @@ Tensor Compute(
Tensor Compute(
const std::string& name,
const std::vector<DimArg>& dim_args,
const std::vector<ExprHandle>& dims,
const std::function<ExprHandle(
const VarHandle&,
const VarHandle&,
const VarHandle&,
const VarHandle&)>& body_func) {
if (dim_args.size() != 4) {
if (dims.size() != 4) {
throw malformed_input("mismatch between body and arg size (4)");
}
std::vector<ExprHandle> dims;
std::vector<VarHandle> args;
unpack_dim_args(dim_args, &dims, &args);
std::vector<VarHandle> args = create_index_vars(dims);
ExprHandle body = body_func(args[0], args[1], args[2], args[3]);
BufHandle buf = Buf::make(name, dims, body.dtype());
return Tensor(buf, args, body);
@ -131,30 +120,30 @@ Tensor Compute(
Tensor Reduce(
const std::string& name,
const std::vector<DimArg>& dim_args,
const std::vector<ExprHandle>& dims,
const Reducer& reducer,
const BufHandle& buffer,
const std::vector<DimArg>& reduce_args) {
const std::vector<ExprHandle>& reduce_dims) {
return Reduce(
name,
dim_args,
dims,
reducer,
[&](ParameterList& p) { return buffer.load(p); },
reduce_args);
reduce_dims);
}
Tensor Reduce(
const std::string& name,
const std::vector<DimArg>& dim_args,
const std::vector<ExprHandle>& dims,
const Reducer& reducer,
Tensor tensor,
const std::vector<DimArg>& reduce_args) {
const std::vector<ExprHandle>& reduce_dims) {
return Reduce(
name,
dim_args,
dims,
reducer,
[&](ParameterList& p) { return tensor.load(p); },
reduce_args);
reduce_dims);
}
} // namespace tensorexpr


@ -4,7 +4,6 @@
#include <functional>
#include <vector>
#include <torch/csrc/jit/tensorexpr/dim_arg.h>
#include <torch/csrc/jit/tensorexpr/expr.h>
#include <torch/csrc/jit/tensorexpr/reduction.h>
@ -73,22 +72,22 @@ class TORCH_API Tensor {
TORCH_API Tensor Compute(
const std::string& func_name,
const std::vector<DimArg>& dim_args,
const std::vector<ExprHandle>& dims,
const std::function<ExprHandle(const VarHandle&)>& body_func);
TORCH_API Tensor Compute(
const std::string& func_name,
const std::vector<DimArg>& dim_args,
const std::vector<ExprHandle>& dims,
const std::function<ExprHandle(const VarHandle&, const VarHandle&)>&
body_func);
TORCH_API Tensor Compute(
const std::string& func_name,
const std::vector<DimArg>& dim_args,
const std::vector<ExprHandle>& dims,
const std::function<
ExprHandle(const VarHandle&, const VarHandle&, const VarHandle&)>&
body_func);
TORCH_API Tensor Compute(
const std::string& func_name,
const std::vector<DimArg>& dim_args,
const std::vector<ExprHandle>& dims,
const std::function<ExprHandle(
const VarHandle&,
const VarHandle&,
@ -96,40 +95,31 @@ TORCH_API Tensor Compute(
const VarHandle&)>& body_func);
TORCH_API Tensor Compute(
const std::string& func_name,
const std::vector<DimArg>& dim_args,
const std::vector<ExprHandle>& dims,
const std::function<ExprHandle(const std::vector<VarHandle>&)>& body_func);
inline void unpack_dim_args(
const std::vector<DimArg>& dim_args,
std::vector<ExprHandle>* dims,
std::vector<VarHandle>* vars) {
dims->clear();
vars->clear();
for (const DimArg& dim_arg : dim_args) {
ExprHandle expr = dim_arg.dim();
dims->push_back(expr);
vars->push_back(VarHandle(alloc<Var>(
dim_arg.name_hint(),
expr.dtype().scalar_type() == ScalarType::Long ? kLong : kInt)));
inline std::vector<VarHandle> create_index_vars(
const std::vector<ExprHandle>& dims) {
std::vector<VarHandle> vars;
vars.reserve(dims.size());
for (const ExprHandle& dim : dims) {
vars.push_back(VarHandle(alloc<Var>(
"i", dim.dtype().scalar_type() == ScalarType::Long ? kLong : kInt)));
}
return vars;
}
// Handle reductions over a Reducer and a body_func which produces values.
template <typename InitFunc, typename BodyFunc>
Tensor Reduce(
const std::string& func_name,
const std::vector<DimArg>& dim_args,
const std::vector<ExprHandle>& dims,
const Reducer& reducer,
const InitFunc& init_func,
const BodyFunc& body_func,
const std::vector<DimArg>& reduce_args) {
std::vector<ExprHandle> dims;
std::vector<VarHandle> vars;
unpack_dim_args(dim_args, &dims, &vars);
std::vector<ExprHandle> reduce_dims;
std::vector<VarHandle> reduce_vars;
unpack_dim_args(reduce_args, &reduce_dims, &reduce_vars);
const std::vector<ExprHandle>& reduce_dims) {
std::vector<VarHandle> vars = create_index_vars(dims);
std::vector<VarHandle> reduce_vars = create_index_vars(reduce_dims);
// If reduce_vars is empty, then it's not a reduction, but rather a simple
// copy
@ -155,45 +145,45 @@ Tensor Reduce(
template <typename BodyFunc>
Tensor Reduce(
const std::string& func_name,
const std::vector<DimArg>& dim_args,
const std::vector<ExprHandle>& dims,
const Reducer& reducer,
const BodyFunc& body_func,
const std::vector<DimArg>& reduce_args) {
const std::vector<ExprHandle>& reduce_dims) {
return Reduce(
func_name,
dim_args,
dims,
reducer,
[&](ParameterList p) { return ExprHandle(reducer.initializer()); },
body_func,
reduce_args);
reduce_dims);
}
// Overload which allows inline lambda functions for the body_func.
template <typename BodyFunc>
Tensor Reduce(
const std::string& func_name,
const std::vector<DimArg>& dim_args,
const std::vector<ExprHandle>& dims,
const Reducer& reducer,
const BodyFunc&& body_func,
const std::vector<DimArg>& reduce_args) {
return Reduce(func_name, dim_args, reducer, body_func, reduce_args);
const std::vector<ExprHandle>& reduce_dims) {
return Reduce(func_name, dims, reducer, body_func, reduce_dims);
}
TORCH_API Tensor Reduce(
const std::string& name,
const std::vector<DimArg>& dim_args,
const std::vector<ExprHandle>& dims,
const Reducer& reducer,
const BufHandle& buffer,
const std::vector<DimArg>& reduce_args);
const std::vector<ExprHandle>& reduce_dims);
// Overload for the common case of all dimensions of a previously Computed
// Tensor.
TORCH_API Tensor Reduce(
const std::string& func_name,
const std::vector<DimArg>& dim_args,
const std::vector<ExprHandle>& dims,
const Reducer& reducer,
Tensor tensor,
const std::vector<DimArg>& reduce_args);
const std::vector<ExprHandle>& reduce_dims);
template <typename... Ts>
inline ExprHandle Tensor::load(const Ts&... ts) const {
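A minimal sketch (assuming the post-change API shown in this diff) of why the updated tests run LoopNest::sanitizeNames: create_index_vars gives every axis the same name hint "i", and sanitizeNames is what produces the stable i, i_1, j, ... loop-variable names the FileCheck patterns expect.

#include <torch/csrc/jit/tensorexpr/ir_simplifier.h>
#include <torch/csrc/jit/tensorexpr/loopnest.h>
#include <torch/csrc/jit/tensorexpr/tensor.h>

using namespace torch::jit::tensorexpr;

StmtPtr buildSanitizedStmt() {
  BufHandle a("a", {64, 32}, kFloat);
  Tensor t = Compute(
      "t", {64, 32}, [&](const VarHandle& i, const VarHandle& j) {
        return a.load(i, j) + 1.f;
      });
  LoopNest nest({t});
  nest.prepareForCodegen();
  // sanitizeNames rewrites the generated "i" name hints into unique names,
  // which is what the rewritten test expectations rely on.
  return LoopNest::sanitizeNames(IRSimplifier::simplify(nest.root_stmt()));
}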


@ -278,17 +278,10 @@ void initTensorExprBindings(PyObject* module) {
self->set_src_value(value.node());
});
py::class_<DimArg>(te, "DimArg")
.def(py::init<const ExprHandle&>())
.def(py::init<const ExprHandle&, const std::string&>());
py::implicitly_convertible<ExprHandle, DimArg>();
py::implicitly_convertible<int32_t, DimArg>();
py::implicitly_convertible<int64_t, DimArg>();
te.def(
"Compute",
[](const std::string& func_name,
const std::vector<DimArg>& dim_args,
const std::vector<ExprHandle>& dim_args,
py::function func) {
if (dim_args.size() == 1) {
return Compute(func_name, dim_args, [&func](const VarHandle& a) {
@ -329,7 +322,7 @@ void initTensorExprBindings(PyObject* module) {
te.def(
"Compute2",
[](const std::string& func_name,
const std::vector<DimArg>& dim_args,
const std::vector<ExprHandle>& dim_args,
py::function func) {
return Compute(
func_name, dim_args, [&func](const std::vector<VarHandle>& dims) {
@ -348,10 +341,10 @@ void initTensorExprBindings(PyObject* module) {
te.def(
"Reduce",
[](const std::string& func_name,
const std::vector<DimArg>& dim_args,
const std::vector<ExprHandle>& dim_args,
const Reducer& reducer,
Tensor buffer,
const std::vector<DimArg>& reduce_args) {
const std::vector<ExprHandle>& reduce_args) {
return Reduce(func_name, dim_args, reducer, buffer, reduce_args);
},
py::return_value_policy::reference);
@ -359,34 +352,34 @@ void initTensorExprBindings(PyObject* module) {
te.def(
"Reduce",
[](const std::string& func_name,
const std::vector<DimArg>& dim_args,
const std::vector<ExprHandle>& dim_args,
const Reducer& reducer,
const BufHandle& buffer,
const std::vector<DimArg>& reduce_args) {
const std::vector<ExprHandle>& reduce_args) {
return Reduce(func_name, dim_args, reducer, buffer, reduce_args);
},
py::return_value_policy::reference);
te.def(
"Reduce",
[](const std::string& func_name,
const std::vector<DimArg>& dim_args,
const std::vector<ExprHandle>& dim_args,
const Reducer& reducer,
const std::function<ExprHandle(const std::vector<VarHandle>&)>&
body_func,
const std::vector<DimArg>& reduce_args) {
const std::vector<ExprHandle>& reduce_args) {
return Reduce(func_name, dim_args, reducer, body_func, reduce_args);
},
py::return_value_policy::reference);
te.def(
"Reduce",
[](const std::string& func_name,
const std::vector<DimArg>& dim_args,
const std::vector<ExprHandle>& dim_args,
const Reducer& reducer,
const std::function<ExprHandle(const std::vector<VarHandle>&)>&
init_func,
const std::function<ExprHandle(const std::vector<VarHandle>&)>&
body_func,
const std::vector<DimArg>& reduce_args) {
const std::vector<ExprHandle>& reduce_args) {
return Reduce(func_name, dim_args, reducer, body_func, reduce_args);
},
py::return_value_policy::reference);