[TensorExpr] Delete DimArg class. (#72390)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72390
This class didn't add much value and only caused more boilerplate code.
This change removes the class and replaces all of its uses with
`ExprHandle`.
A side effect of this change is that loop variables get different names,
which caused massive mechanical changes in our tests.
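For illustration, here is a minimal sketch of the API before and after the change. The buffer and extent names (`a`, `M`, `N`, `t`, `s`) are made up for this example and are not taken from the diff; the namespace is assumed to be the usual torch::jit::tensorexpr one. Previously each dimension was a `DimArg` pairing an extent with an explicit loop-variable name; now dimensions are plain `ExprHandle`s and loop-variable names are generated, which is why the tests see renamed variables.

  // Hypothetical sketch; names are illustrative only.
  using namespace torch::jit::tensorexpr;
  const int M = 8, N = 16;
  BufHandle a("a", {M, N}, kFloat);

  // Before this change: each dimension carried an explicit loop-variable name via DimArg.
  // (No longer compiles after this commit.)
  // Tensor t = Compute(
  //     "t",
  //     {{M, "m"}, {N, "n"}},
  //     [&](const VarHandle& m, const VarHandle& n) { return a.load(m, n); });
  // Tensor s = Reduce("s", {{M, "m"}}, Sum(), a, {{N, "n"}});

  // After this change: dimensions are plain ExprHandles; loop names are generated.
  Tensor t = Compute(
      "t", {M, N}, [&](const VarHandle& m, const VarHandle& n) { return a.load(m, n); });
  Tensor s = Reduce("s", {M}, Sum(), a, {N});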
Test Plan: Imported from OSS
Reviewed By: navahgar
Differential Revision: D34030296
Pulled By: ZolotukhinM
fbshipit-source-id: 2ba4e313506a43ab129a10d99e72b638b7d40108
(cherry picked from commit c2ec46a058)
This commit is contained in:
parent
9123e9b3b5
commit
1855b14922
@@ -82,10 +82,8 @@ BENCHMARK_DEFINE_F(BatchNorm, NNC)(benchmark::State& state) {
   VarHandle eps("eps", kFloat);

   using axis = const VarHandle&;
-  Tensor output = Compute(
-      "output",
-      {{N_, "N"}, {C_, "C"}, {H_, "H"}, {W_, "W"}},
-      [&](axis n, axis c, axis h, axis w) {
+  Tensor output =
+      Compute("output", {N_, C_, H_, W_}, [&](axis n, axis c, axis h, axis w) {
         // Compute affine terms.
         auto inv_var = FloatImm::make(1.0f) / sqrt(var.load(c) + eps);
         auto weight_v = weight.load(c);
@@ -143,10 +141,8 @@ BENCHMARK_DEFINE_F(BatchNorm, NNCRelu)(benchmark::State& state) {
   VarHandle eps("eps", kFloat);

   using axis = const VarHandle&;
-  Tensor output = Compute(
-      "output",
-      {{N_, "N"}, {C_, "C"}, {H_, "H"}, {W_, "W"}},
-      [&](axis n, axis c, axis h, axis w) {
+  Tensor output =
+      Compute("output", {N_, C_, H_, W_}, [&](axis n, axis c, axis h, axis w) {
         // Compute affine terms.
         auto inv_var = FloatImm::make(1.0f) / sqrt(var.load(c) + eps);
         auto weight_v = weight.load(c);
@@ -12,26 +12,21 @@ static void BM_CompileSwish(benchmark::State& state) {
   constexpr int N = 512;
   te::VarHandle n("n", te::kInt);
   te::BufHandle A("A", {N}, te::kFloat);
-  te::Tensor relu =
-      te::Compute("relu", {{n, "n"}}, [&](const te::VarHandle& i) {
-        return te::Max::make(A.load(i), 0.f, false);
-      });
-  te::Tensor min6 =
-      te::Compute("min6", {{n, "n"}}, [&](const te::VarHandle& i) {
-        return te::Min::make(relu.load(i), 6.f, false);
-      });
-  te::Tensor plus3 =
-      te::Compute("plus3", {{n, "n"}}, [&](const te::VarHandle& i) {
-        return min6.load(i) + 3.f;
-      });
-  te::Tensor times =
-      te::Compute("times", {{n, "n"}}, [&](const te::VarHandle& i) {
-        return A.load(i) * plus3.load(i);
-      });
-  te::Tensor sixth =
-      te::Compute("sixth", {{n, "n"}}, [&](const te::VarHandle& i) {
-        return times.load(i) * 1.f / 6.f;
-      });
+  te::Tensor relu = te::Compute("relu", {n}, [&](const te::VarHandle& i) {
+    return te::Max::make(A.load(i), 0.f, false);
+  });
+  te::Tensor min6 = te::Compute("min6", {n}, [&](const te::VarHandle& i) {
+    return te::Min::make(relu.load(i), 6.f, false);
+  });
+  te::Tensor plus3 = te::Compute("plus3", {n}, [&](const te::VarHandle& i) {
+    return min6.load(i) + 3.f;
+  });
+  te::Tensor times = te::Compute("times", {n}, [&](const te::VarHandle& i) {
+    return A.load(i) * plus3.load(i);
+  });
+  te::Tensor sixth = te::Compute("sixth", {n}, [&](const te::VarHandle& i) {
+    return times.load(i) * 1.f / 6.f;
+  });
   te::LoopNest nest({sixth}, {relu, min6, plus3, times, sixth});
   for (auto tensor : {relu, min6, plus3, times}) {
     nest.computeInline(tensor.buf());
@ -46,26 +41,20 @@ static void BM_CompileSwishLLVMOnly(benchmark::State& state) {
|
|||
constexpr int N = 512;
|
||||
te::VarHandle n("n", te::kInt);
|
||||
te::BufHandle A("A", {N}, te::kFloat);
|
||||
te::Tensor relu =
|
||||
te::Compute("relu", {{n, "n"}}, [&](const te::VarHandle& i) {
|
||||
return te::Max::make(A.load(i), 0.f, false);
|
||||
});
|
||||
te::Tensor min6 =
|
||||
te::Compute("min6", {{n, "n"}}, [&](const te::VarHandle& i) {
|
||||
return te::Min::make(relu.load(i), 6.f, false);
|
||||
});
|
||||
te::Tensor plus3 =
|
||||
te::Compute("plus3", {{n, "n"}}, [&](const te::VarHandle& i) {
|
||||
return min6.load(i) + 3.f;
|
||||
});
|
||||
te::Tensor times =
|
||||
te::Compute("times", {{n, "n"}}, [&](const te::VarHandle& i) {
|
||||
return A.load(i) * plus3.load(i);
|
||||
});
|
||||
te::Tensor sixth =
|
||||
te::Compute("sixth", {{n, "n"}}, [&](const te::VarHandle& i) {
|
||||
return times.load(i) * 1.f / 6.f;
|
||||
});
|
||||
te::Tensor relu = te::Compute("relu", {n}, [&](const te::VarHandle& i) {
|
||||
return te::Max::make(A.load(i), 0.f, false);
|
||||
});
|
||||
te::Tensor min6 = te::Compute("min6", {n}, [&](const te::VarHandle& i) {
|
||||
return te::Min::make(relu.load(i), 6.f, false);
|
||||
});
|
||||
te::Tensor plus3 = te::Compute(
|
||||
"plus3", {n}, [&](const te::VarHandle& i) { return min6.load(i) + 3.f; });
|
||||
te::Tensor times = te::Compute("times", {n}, [&](const te::VarHandle& i) {
|
||||
return A.load(i) * plus3.load(i);
|
||||
});
|
||||
te::Tensor sixth = te::Compute("sixth", {n}, [&](const te::VarHandle& i) {
|
||||
return times.load(i) * 1.f / 6.f;
|
||||
});
|
||||
te::LoopNest nest({sixth}, {relu, min6, plus3, times, sixth});
|
||||
for (auto tensor : {relu, min6, plus3, times}) {
|
||||
nest.computeInline(tensor.buf());
|
||||
|
|
|
|||
|
|
@@ -61,7 +61,7 @@ class ConcatBench : public benchmark::Fixture {

     Tensor output = Compute(
         "aten_cat",
-        {{output_size_[0], "M"}, {output_size_[1], "N"}},
+        {output_size_[0], output_size_[1]},
         [&](const VarHandle& m, const VarHandle& n) {
           int d = 0;
           std::vector<int> cumulative_concat_dim_sizes(num_inputs);
@@ -44,12 +44,12 @@ BENCHMARK_DEFINE_F(Gemm, TensorExprNoopt)(benchmark::State& state) {
   te::BufHandle BP("B", {K, N}, te::kFloat);
   te::Tensor CT = te::Reduce(
       "gemm",
-      {{M, "M"}, {N, "N"}},
+      {M, N},
       te::Sum(),
       [&](const te::ExprHandle& m,
           const te::ExprHandle& n,
           const te::ExprHandle& k) { return AP.load(m, k) * BP.load(k, n); },
-      {{K, "K"}});
+      {K});
   te::LoopNest loop({CT});
   loop.prepareForCodegen();
   te::StmtPtr s = loop.root_stmt();
@ -66,12 +66,12 @@ BENCHMARK_DEFINE_F(Gemm, TensorExprTile32x32)(benchmark::State& state) {
|
|||
te::BufHandle BP("B", {K, N}, te::kFloat);
|
||||
te::Tensor CT = te::Reduce(
|
||||
"gemm",
|
||||
{{M, "M"}, {N, "N"}},
|
||||
{M, N},
|
||||
te::Sum(),
|
||||
[&](const te::ExprHandle& m,
|
||||
const te::ExprHandle& n,
|
||||
const te::ExprHandle& k) { return AP.load(m, k) * BP.load(k, n); },
|
||||
{{K, "K"}});
|
||||
{K});
|
||||
te::LoopNest loop({CT});
|
||||
|
||||
{
|
||||
|
|
@ -124,12 +124,12 @@ BENCHMARK_DEFINE_F(Gemm, TensorExprTile4x16)(benchmark::State& state) {
|
|||
te::BufHandle BP("B", {K, N}, te::kFloat);
|
||||
te::Tensor CT = te::Reduce(
|
||||
"gemm",
|
||||
{{M, "M"}, {N, "N"}},
|
||||
{M, N},
|
||||
te::Sum(),
|
||||
[&](const te::ExprHandle& m,
|
||||
const te::ExprHandle& n,
|
||||
const te::ExprHandle& k) { return AP.load(m, k) * BP.load(k, n); },
|
||||
{{K, "K"}});
|
||||
{K});
|
||||
te::LoopNest loop({CT});
|
||||
|
||||
{
|
||||
|
|
@ -182,12 +182,12 @@ BENCHMARK_DEFINE_F(Gemm, TensorExprTile4x16VecUnroll)(benchmark::State& state) {
|
|||
te::BufHandle BP("B", {K, N}, te::kFloat);
|
||||
te::Tensor CT = te::Reduce(
|
||||
"gemm",
|
||||
{{M, "M"}, {N, "N"}},
|
||||
{M, N},
|
||||
te::Sum(),
|
||||
[&](const te::ExprHandle& m,
|
||||
const te::ExprHandle& n,
|
||||
const te::ExprHandle& k) { return AP.load(m, k) * BP.load(k, n); },
|
||||
{{K, "K"}});
|
||||
{K});
|
||||
te::LoopNest loop({CT});
|
||||
|
||||
{
|
||||
|
|
@ -248,12 +248,12 @@ BENCHMARK_DEFINE_F(Gemm, TensorExprTile4x16Cache)(benchmark::State& state) {
|
|||
te::BufHandle BP("B", {K, N}, te::kFloat);
|
||||
te::Tensor CT = te::Reduce(
|
||||
"gemm",
|
||||
{{M, "M"}, {N, "N"}},
|
||||
{M, N},
|
||||
te::Sum(),
|
||||
[&](const te::ExprHandle& m,
|
||||
const te::ExprHandle& n,
|
||||
const te::ExprHandle& k) { return AP.load(m, k) * BP.load(k, n); },
|
||||
{{K, "K"}});
|
||||
{K});
|
||||
te::LoopNest loop({CT});
|
||||
|
||||
{
|
||||
|
|
|
|||
|
|
@@ -38,7 +38,7 @@ class ParallelAdd : public benchmark::Fixture {
 BENCHMARK_DEFINE_F(ParallelAdd, Simple)(benchmark::State& state) {
   BufHandle a_buf("a", {M}, kFloat);
   BufHandle b_buf("b", {M}, kFloat);
-  Tensor c_tensor = Compute("c", {{M, "m"}}, [&](const VarHandle& m) {
+  Tensor c_tensor = Compute("c", {M}, [&](const VarHandle& m) {
     return a_buf.load(m) + b_buf.load(m);
   });
   LoopNest loop_nest({c_tensor});
@ -235,12 +235,12 @@ BENCHMARK_DEFINE_F(Reduce1D, TeNaive)(benchmark::State& state) {
|
|||
te::BufHandle AP("A", {M}, te::kFloat);
|
||||
te::Tensor BT = te::Reduce(
|
||||
"reduce_full",
|
||||
{{1, "N"}},
|
||||
{1},
|
||||
te::Sum(),
|
||||
[&](const te::ExprHandle& n, const te::ExprHandle& m) {
|
||||
return AP.load(m);
|
||||
},
|
||||
{{M, "M"}});
|
||||
{M});
|
||||
|
||||
te::LoopNest loop({BT});
|
||||
loop.prepareForCodegen();
|
||||
|
|
@ -266,12 +266,12 @@ BENCHMARK_DEFINE_F(Reduce1D, TeSplitTail)(benchmark::State& state) {
|
|||
te::BufHandle AP("A", {M}, te::kFloat);
|
||||
te::Tensor BT = te::Reduce(
|
||||
"reduce_full",
|
||||
{{1, "N"}},
|
||||
{1},
|
||||
te::Sum(),
|
||||
[&](const te::ExprHandle& n, const te::ExprHandle& m) {
|
||||
return AP.load(m);
|
||||
},
|
||||
{{M, "M"}});
|
||||
{M});
|
||||
|
||||
te::LoopNest loop({BT});
|
||||
const int kChunkSize = 8;
|
||||
|
|
@ -305,12 +305,12 @@ BENCHMARK_DEFINE_F(Reduce1D, TeSplitMask)(benchmark::State& state) {
|
|||
te::BufHandle AP("A", {M}, te::kFloat);
|
||||
te::Tensor BT = te::Reduce(
|
||||
"reduce_full",
|
||||
{{1, "N"}},
|
||||
{1},
|
||||
te::Sum(),
|
||||
[&](const te::ExprHandle& n, const te::ExprHandle& m) {
|
||||
return AP.load(m);
|
||||
},
|
||||
{{M, "M"}});
|
||||
{M});
|
||||
|
||||
te::LoopNest loop({BT});
|
||||
const int kChunkSize = 8;
|
||||
|
|
@ -349,7 +349,7 @@ BENCHMARK_DEFINE_F(Reduce1D, TeRfactorV1)(benchmark::State& state) {
|
|||
{},
|
||||
te::Sum(),
|
||||
[&](const te::ExprHandle& m) { return AP.load(m); },
|
||||
{{M, "M"}});
|
||||
{M});
|
||||
|
||||
te::LoopNest loop({BT});
|
||||
te::BufPtr rfac_buf;
|
||||
|
|
|
|||
|
|
@ -46,13 +46,13 @@ class SignedLog1pBench : public benchmark::Fixture {
|
|||
"input", {input_size_int_[0], input_size_int_[1]}, kFloat);
|
||||
Tensor abs_result = Compute(
|
||||
"aten_abs",
|
||||
{{input_size_int_[0], "M"}, {input_size_int_[1], "N"}},
|
||||
{input_size_int_[0], input_size_int_[1]},
|
||||
[&](const VarHandle& m, const VarHandle& n) {
|
||||
return abs(input_ph.load(m, n));
|
||||
});
|
||||
Tensor log1p_result = Compute(
|
||||
"aten_log1p",
|
||||
{{input_size_int_[0], "M"}, {input_size_int_[1], "N"}},
|
||||
{input_size_int_[0], input_size_int_[1]},
|
||||
[&](const VarHandle& m, const VarHandle& n) {
|
||||
return log1p(abs_result.load(m, n));
|
||||
});
|
||||
|
|
@ -60,7 +60,7 @@ class SignedLog1pBench : public benchmark::Fixture {
|
|||
computeSign({input_ph}, {input_size_int_[0], input_size_int_[1]});
|
||||
Tensor output = Compute(
|
||||
"aten_mul",
|
||||
{{input_size_int_[0], "M"}, {input_size_int_[1], "N"}},
|
||||
{input_size_int_[0], input_size_int_[1]},
|
||||
[&](const VarHandle& m, const VarHandle& n) {
|
||||
return sign_result.load(m, n) * log1p_result.load(m, n);
|
||||
});
|
||||
|
|
@ -94,13 +94,13 @@ class SignedLog1pBench : public benchmark::Fixture {
|
|||
"input", {input_size_int_[0], input_size_int_[1]}, kFloat);
|
||||
Tensor abs_result = Compute(
|
||||
"aten_abs",
|
||||
{{input_size_int_[0], "M"}, {input_size_int_[1], "N"}},
|
||||
{input_size_int_[0], input_size_int_[1]},
|
||||
[&](const VarHandle& m, const VarHandle& n) {
|
||||
return abs(input_ph.load(m, n));
|
||||
});
|
||||
Tensor log_vml_result = Compute(
|
||||
"aten_log1p",
|
||||
{{input_size_int_[0], "M"}, {input_size_int_[1], "N"}},
|
||||
{input_size_int_[0], input_size_int_[1]},
|
||||
[&](const VarHandle& m, const VarHandle& n) {
|
||||
return log_vml(abs_result.load(m, n) + ExprHandle(1));
|
||||
});
|
||||
|
|
@ -108,7 +108,7 @@ class SignedLog1pBench : public benchmark::Fixture {
|
|||
computeSign({input_ph}, {input_size_int_[0], input_size_int_[1]});
|
||||
Tensor output = Compute(
|
||||
"aten_mul",
|
||||
{{input_size_int_[0], "M"}, {input_size_int_[1], "N"}},
|
||||
{input_size_int_[0], input_size_int_[1]},
|
||||
[&](const VarHandle& m, const VarHandle& n) {
|
||||
return sign_result.load(m, n) * log_vml_result.load(m, n);
|
||||
});
|
||||
|
|
|
|||
|
|
@ -49,8 +49,7 @@ TEST(BoundsInference, _1) {
|
|||
// {{b, kStore, 0, 99}, {a, kLoad, 0, 99}}
|
||||
ExprHandle n(100);
|
||||
BufHandle a("a", {n}, kFloat);
|
||||
Tensor b =
|
||||
Compute("b", {{n, "i"}}, [&](const VarHandle& i) { return a.load(i); });
|
||||
Tensor b = Compute("b", {n}, [&](const VarHandle& i) { return a.load(i); });
|
||||
LoopNest l({b});
|
||||
auto bounds_info = inferBounds(l.root_stmt());
|
||||
|
||||
|
|
@ -73,8 +72,7 @@ TEST(BoundsInference, _2) {
|
|||
// {{b, kStore, 0, n-1}, {a, kLoad, 0, n-1}}
|
||||
VarHandle n("n", kInt);
|
||||
BufHandle a("a", {n}, kFloat);
|
||||
Tensor b =
|
||||
Compute("b", {{n, "i"}}, [&](const VarHandle& i) { return a.load(i); });
|
||||
Tensor b = Compute("b", {n}, [&](const VarHandle& i) { return a.load(i); });
|
||||
LoopNest l({b});
|
||||
auto bounds_info = inferBounds(l.root_stmt());
|
||||
|
||||
|
|
@ -97,9 +95,8 @@ TEST(BoundsInference, _3) {
|
|||
// {{b, kStore, 0, 99}, {a, kLoad, 0, 109}}
|
||||
ExprHandle n(100);
|
||||
BufHandle a("a", {n + 10}, kFloat);
|
||||
Tensor b = Compute("b", {{n, "i"}}, [&](const VarHandle& i) {
|
||||
return a.load(i) * a.load(i + 10);
|
||||
});
|
||||
Tensor b = Compute(
|
||||
"b", {n}, [&](const VarHandle& i) { return a.load(i) * a.load(i + 10); });
|
||||
LoopNest l({b});
|
||||
auto bounds_info = inferBounds(l.root_stmt());
|
||||
|
||||
|
|
@ -126,14 +123,12 @@ TEST(BoundsInference, _4) {
|
|||
ExprHandle W(320);
|
||||
ExprHandle H(200);
|
||||
BufHandle a("a", {H, W}, kFloat);
|
||||
Tensor b = Compute(
|
||||
"b", {{H, "y"}, {W, "x"}}, [&](const VarHandle& y, const VarHandle& x) {
|
||||
return x * y;
|
||||
});
|
||||
Tensor c = Compute(
|
||||
"c", {{H, "y"}, {W, "x"}}, [&](const VarHandle& y, const VarHandle& x) {
|
||||
return a.load(y, x) * b.load(y, x);
|
||||
});
|
||||
Tensor b = Compute("b", {H, W}, [&](const VarHandle& y, const VarHandle& x) {
|
||||
return x * y;
|
||||
});
|
||||
Tensor c = Compute("c", {H, W}, [&](const VarHandle& y, const VarHandle& x) {
|
||||
return a.load(y, x) * b.load(y, x);
|
||||
});
|
||||
LoopNest l({c});
|
||||
std::vector<ForPtr> loops = l.getLoopStmtsFor(c);
|
||||
StmtPtr body = l.getLoopBodyFor(c);
|
||||
|
|
@ -204,8 +199,7 @@ TEST(BoundsInference, _5) {
|
|||
// b[i_tail + (100/16)*16] = a[i_tail + (100/16)*16];
|
||||
ExprHandle n(100);
|
||||
BufHandle a("a", {n}, kFloat);
|
||||
Tensor b =
|
||||
Compute("b", {{n, "i"}}, [&](const VarHandle& i) { return a.load(i); });
|
||||
Tensor b = Compute("b", {n}, [&](const VarHandle& i) { return a.load(i); });
|
||||
LoopNest l({b});
|
||||
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
|
||||
|
|
@ -258,12 +252,11 @@ TEST(BoundsInference, _6) {
|
|||
ExprHandle CW(32);
|
||||
ExprHandle CH(20);
|
||||
BufHandle a("a", {H, W}, kFloat);
|
||||
Tensor b = Compute(
|
||||
"b", {{H, "y"}, {W, "x"}}, [&](const VarHandle& y, const VarHandle& x) {
|
||||
return x * y;
|
||||
});
|
||||
Tensor c = Compute(
|
||||
"c", {{CH, "y"}, {CW, "x"}}, [&](const VarHandle& y, const VarHandle& x) {
|
||||
Tensor b = Compute("b", {H, W}, [&](const VarHandle& y, const VarHandle& x) {
|
||||
return x * y;
|
||||
});
|
||||
Tensor c =
|
||||
Compute("c", {CH, CW}, [&](const VarHandle& y, const VarHandle& x) {
|
||||
return a.load(y + 100, x + 100) * b.load(y * 2, x * 5);
|
||||
});
|
||||
LoopNest l({c});
|
||||
|
|
@ -325,10 +318,9 @@ TEST(BoundsInference, _6) {
|
|||
TEST(BoundsInference, Adjacent) {
|
||||
ExprHandle H(6);
|
||||
BufHandle a("a", {20}, kFloat);
|
||||
Tensor b =
|
||||
Compute("b", {{H, "x"}}, [&](const VarHandle& x) { return a.load(x); });
|
||||
Tensor c = Compute(
|
||||
"c", {{H, "x"}}, [&](const VarHandle& x) { return a.load(x + H); });
|
||||
Tensor b = Compute("b", {H}, [&](const VarHandle& x) { return a.load(x); });
|
||||
Tensor c =
|
||||
Compute("c", {H}, [&](const VarHandle& x) { return a.load(x + H); });
|
||||
LoopNest l({b, c});
|
||||
std::vector<ForPtr> loops = NodeFinder<For>::find(l.root_stmt());
|
||||
|
||||
|
|
@ -383,12 +375,11 @@ TEST(BoundsInference, Adjacent) {
|
|||
|
||||
TEST(BoundsInference, MultipleTopLoopLoad) {
|
||||
BufHandle a("a", {100}, kFloat);
|
||||
Tensor b =
|
||||
Compute("b", {{64, "x"}}, [&](const VarHandle& x) { return a.load(x); });
|
||||
Tensor c = Compute(
|
||||
"c", {{32, "x"}}, [&](const VarHandle& x) { return a.load(x + 10); });
|
||||
Tensor d = Compute(
|
||||
"d", {{96, "x"}}, [&](const VarHandle& x) { return a.load(x + 2); });
|
||||
Tensor b = Compute("b", {64}, [&](const VarHandle& x) { return a.load(x); });
|
||||
Tensor c =
|
||||
Compute("c", {32}, [&](const VarHandle& x) { return a.load(x + 10); });
|
||||
Tensor d =
|
||||
Compute("d", {96}, [&](const VarHandle& x) { return a.load(x + 2); });
|
||||
LoopNest l({b, c, d});
|
||||
|
||||
auto bounds_info = inferBounds(l.root_stmt());
|
||||
|
|
@ -496,16 +487,15 @@ TEST(BoundsInference, MultipleTopLoopStore) {
|
|||
}
|
||||
|
||||
TEST(BoundsInference, CacheReads) {
|
||||
Tensor A = Compute(
|
||||
"A", {{64, "i"}, {64, "j"}}, [](const VarHandle& i, const VarHandle& j) {
|
||||
return i * j;
|
||||
});
|
||||
Tensor B = Compute(
|
||||
"B", {{20, "i"}, {10, "j"}}, [&](const VarHandle& i, const VarHandle& j) {
|
||||
Tensor A = Compute("A", {64, 64}, [](const VarHandle& i, const VarHandle& j) {
|
||||
return i * j;
|
||||
});
|
||||
Tensor B =
|
||||
Compute("B", {20, 10}, [&](const VarHandle& i, const VarHandle& j) {
|
||||
return A.load(i + 30, j + 3);
|
||||
});
|
||||
Tensor C = Compute(
|
||||
"C", {{20, "i"}, {10, "j"}}, [&](const VarHandle& i, const VarHandle& j) {
|
||||
Tensor C =
|
||||
Compute("C", {20, 10}, [&](const VarHandle& i, const VarHandle& j) {
|
||||
return A.load(i + 10, j + 20) + A.load(i + 30, j + 40);
|
||||
});
|
||||
|
||||
|
|
@ -562,7 +552,7 @@ TEST(BoundsInference, CacheReads) {
|
|||
TEST(BoundsInference, Flattened) {
|
||||
Tensor b = Compute(
|
||||
"b",
|
||||
{{3, "z"}, {4, "y"}, {5, "x"}},
|
||||
{3, 4, 5},
|
||||
[&](const VarHandle& z, const VarHandle& y, const VarHandle& x) {
|
||||
return x * y + z;
|
||||
});
|
||||
|
|
@ -637,14 +627,12 @@ TEST(BoundsInference, GetPotentialHazards) {
|
|||
}
|
||||
|
||||
TEST(BoundsInference, GetPotentialHazardsLoopNoHazard) {
|
||||
Tensor A = Compute(
|
||||
"A", {{64, "i"}, {64, "j"}}, [](const VarHandle& i, const VarHandle& j) {
|
||||
return i * j;
|
||||
});
|
||||
Tensor B = Compute(
|
||||
"B", {{64, "i"}, {64, "j"}}, [](const VarHandle& i, const VarHandle& j) {
|
||||
return (i + 1) * (j + 1);
|
||||
});
|
||||
Tensor A = Compute("A", {64, 64}, [](const VarHandle& i, const VarHandle& j) {
|
||||
return i * j;
|
||||
});
|
||||
Tensor B = Compute("B", {64, 64}, [](const VarHandle& i, const VarHandle& j) {
|
||||
return (i + 1) * (j + 1);
|
||||
});
|
||||
|
||||
LoopNest l({A, B});
|
||||
|
||||
|
|
@ -663,12 +651,11 @@ TEST(BoundsInference, GetPotentialHazardsLoopNoHazard) {
|
|||
}
|
||||
|
||||
TEST(BoundsInference, GetPotentialHazardsLoopCall) {
|
||||
Tensor A = Compute(
|
||||
"A", {{64, "i"}, {64, "j"}}, [](const VarHandle& i, const VarHandle& j) {
|
||||
return i * j;
|
||||
});
|
||||
Tensor B = Compute(
|
||||
"B", {{64, "i"}, {64, "j"}}, [&](const VarHandle& i, const VarHandle& j) {
|
||||
Tensor A = Compute("A", {64, 64}, [](const VarHandle& i, const VarHandle& j) {
|
||||
return i * j;
|
||||
});
|
||||
Tensor B =
|
||||
Compute("B", {64, 64}, [&](const VarHandle& i, const VarHandle& j) {
|
||||
return A.load(i, j) + 5;
|
||||
});
|
||||
|
||||
|
|
@ -688,10 +675,9 @@ TEST(BoundsInference, GetPotentialHazardsLoopCall) {
|
|||
}
|
||||
|
||||
TEST(BoundsInference, GetPotentialHazardsLoopSplit) {
|
||||
Tensor A = Compute(
|
||||
"A", {{64, "i"}, {64, "j"}}, [](const VarHandle& i, const VarHandle& j) {
|
||||
return i * j;
|
||||
});
|
||||
Tensor A = Compute("A", {64, 64}, [](const VarHandle& i, const VarHandle& j) {
|
||||
return i * j;
|
||||
});
|
||||
|
||||
LoopNest l({A});
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
|
||||
|
|
|
|||
|
|
@ -191,7 +191,7 @@ TEST(Conv, Conv2D) {
|
|||
|
||||
te::Tensor conv = te::Reduce(
|
||||
"conv",
|
||||
{{N, "n"}, {K, "k"}, {OH, "oh"}, {OW, "ow"}},
|
||||
{N, K, OH, OW},
|
||||
te::Sum(),
|
||||
// FIXME: We have to use a `std::vector` parameter here and then unpack
|
||||
// it, because we don't have an overload allowing for an arbitrary number
|
||||
|
|
@ -211,7 +211,7 @@ TEST(Conv, Conv2D) {
|
|||
},
|
||||
// FIXME: If you forget one of the reduction dims, you get a segfault.
|
||||
// Could that be caught by a verifier?
|
||||
{{C, "c"}, {R, "r"}, {S, "s"}});
|
||||
{C, R, S});
|
||||
|
||||
// FIXME: It'd be nice to have a single header that pulls in things like
|
||||
// LoopNest, IRSimplifier, etc.
|
||||
|
|
|
|||
|
|
@ -37,9 +37,9 @@ static void testCudaTestVectorAdd01_impl() {
|
|||
Tensor c = Compute(
|
||||
"c",
|
||||
{
|
||||
{num_iter, "n"},
|
||||
{block_count, "b_id"},
|
||||
{block_size, "t_id"},
|
||||
num_iter,
|
||||
block_count,
|
||||
block_size,
|
||||
},
|
||||
[&](const VarHandle& n, const VarHandle& b_id, const VarHandle& t_id) {
|
||||
return a_buf.load(n, b_id, t_id) + b_buf.load(n, b_id, t_id);
|
||||
|
|
@ -101,9 +101,9 @@ TEST(Cuda, Sigmoid_CUDA) {
|
|||
Tensor c = Compute(
|
||||
"c",
|
||||
{
|
||||
{num_iter, "n"},
|
||||
{block_count, "b_id"},
|
||||
{block_size, "t_id"},
|
||||
num_iter,
|
||||
block_count,
|
||||
block_size,
|
||||
},
|
||||
[&](const VarHandle& n, const VarHandle& b_id, const VarHandle& t_id) {
|
||||
return sigmoid(sigmoid(a_buf.load(n, b_id, t_id)));
|
||||
|
|
@ -163,12 +163,9 @@ TEST(Cuda, TestVectorAdd01_CUDA) {
|
|||
static void testCudaTestVectorAdd02_impl(int64_t N, int64_t block_size) {
|
||||
BufHandle a_buf("a", {N}, kFloat);
|
||||
BufHandle b_buf("b", {N}, kFloat);
|
||||
Tensor c = Compute(
|
||||
"c",
|
||||
{
|
||||
{N, "N"},
|
||||
},
|
||||
[&](const VarHandle& n) { return a_buf.load(n) + b_buf.load(n); });
|
||||
Tensor c = Compute("c", {N}, [&](const VarHandle& n) {
|
||||
return a_buf.load(n) + b_buf.load(n);
|
||||
});
|
||||
LoopNest l({c});
|
||||
ForPtr n_inner;
|
||||
std::vector<ForPtr> loops = l.getLoopStmtsFor(c);
|
||||
|
|
@ -222,7 +219,7 @@ TEST(Cuda, TestVectorAdd02_CUDA) {
|
|||
TEST(Cuda, HalfCast_CUDA) {
|
||||
auto half = ToDtype<at::Half>();
|
||||
BufHandle a("a", {4}, half);
|
||||
Tensor b = Compute("b", {{4, "n"}}, [&](const VarHandle& i) {
|
||||
Tensor b = Compute("b", {4}, [&](const VarHandle& i) {
|
||||
return Cast::make(kFloat, a.load(i));
|
||||
});
|
||||
|
||||
|
|
@ -263,8 +260,8 @@ TEST(Cuda, DynamicShape2D_CUDA) {
|
|||
VarHandle n("n", kInt);
|
||||
BufHandle a("a", {m, n}, kFloat);
|
||||
BufHandle b("b", {m, n}, kFloat);
|
||||
Tensor c = Compute(
|
||||
"c", {{m, "m"}, {n, "n"}}, [&](const VarHandle& i, const VarHandle& j) {
|
||||
Tensor c =
|
||||
Compute("c", {m, n}, [&](const VarHandle& i, const VarHandle& j) {
|
||||
return a.load(i, j) + b.load(i, j);
|
||||
});
|
||||
LoopNest l({c});
|
||||
|
|
@ -326,9 +323,9 @@ TEST(Cuda, TestRand01_CUDA) {
|
|||
Tensor c = Compute(
|
||||
"c",
|
||||
{
|
||||
{num_iter, "n"},
|
||||
{block_count, "b_id"},
|
||||
{block_size, "t_id"},
|
||||
num_iter,
|
||||
block_count,
|
||||
block_size,
|
||||
},
|
||||
[&](const VarHandle& n, const VarHandle& b_id, const VarHandle& t_id) {
|
||||
return Intrinsics::make(IntrinsicsOp::kRand, kFloat);
|
||||
|
|
@ -381,8 +378,8 @@ TEST(Cuda, DynamicShapeSplit_CUDA) {
|
|||
constexpr int64_t N = 4096;
|
||||
VarHandle n("n", kLong);
|
||||
BufHandle a("a", {n}, kFloat);
|
||||
Tensor b = Compute(
|
||||
"b", {{n, "n"}}, [&](const VarHandle& i) { return a.load(i) * 2.0f; });
|
||||
Tensor b =
|
||||
Compute("b", {n}, [&](const VarHandle& i) { return a.load(i) * 2.0f; });
|
||||
LoopNest l({b});
|
||||
ForPtr inner;
|
||||
std::vector<ForPtr> loops = l.getLoopStmtsFor(b);
|
||||
|
|
@ -914,15 +911,15 @@ TEST(Cuda, LocalMemReduce_1_CUDA) {
|
|||
TEST(Cuda, HalfSupport_CUDA) {
|
||||
auto half = ToDtype<at::Half>();
|
||||
BufHandle a("a", {4}, half);
|
||||
Tensor b = Compute("b", {{4, "n"}}, [&](const VarHandle& i) {
|
||||
Tensor b = Compute("b", {4}, [&](const VarHandle& i) {
|
||||
return Cast::make(half, ExprHandle(2.0f) * a.load(i));
|
||||
});
|
||||
|
||||
Tensor c = Compute("c", {{4, "n"}}, [&](const VarHandle& i) {
|
||||
Tensor c = Compute("c", {4}, [&](const VarHandle& i) {
|
||||
return Cast::make(kFloat, Cast::make(half, ExprHandle(42)) + b.load(i));
|
||||
});
|
||||
|
||||
Tensor d = Compute("d", {{4, "n"}}, [&](const VarHandle& i) {
|
||||
Tensor d = Compute("d", {4}, [&](const VarHandle& i) {
|
||||
return Cast::make(half, c.load(i));
|
||||
});
|
||||
|
||||
|
|
@ -971,7 +968,7 @@ TEST(Cuda, HalfSupport_CUDA) {
|
|||
TEST(Cuda, HalfPropagation_CUDA) {
|
||||
auto half = ToDtype<at::Half>();
|
||||
BufHandle a("a", {4}, half);
|
||||
Tensor relu = Compute("relu", {{4, "n"}}, [&](const VarHandle& i) {
|
||||
Tensor relu = Compute("relu", {4}, [&](const VarHandle& i) {
|
||||
return Max::make(a.load(i), ExprHandle(alloc<HalfImm>(0)), true);
|
||||
});
|
||||
|
||||
|
|
@ -987,8 +984,8 @@ TEST(Cuda, HalfPropagation_CUDA) {
|
|||
const std::string& verification_pattern =
|
||||
R"IR(
|
||||
# CHECK: for (
|
||||
# CHECK: float v = float(a[n]);
|
||||
# CHECK: relu[n] = half(Max(v, 0.f
|
||||
# CHECK: float v = float(a[i]);
|
||||
# CHECK: relu[i] = half(Max(v, 0.f
|
||||
# CHECK: })IR";
|
||||
|
||||
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
|
||||
|
|
@ -1020,7 +1017,7 @@ TEST(Cuda, UnusedHalfArgument_CUDA) {
|
|||
BufHandle a("a", {4}, kFloat);
|
||||
auto half = ToDtype<at::Half>();
|
||||
BufHandle b("b", {4}, half);
|
||||
Tensor relu = Compute("relu", {{4, "n"}}, [&](const VarHandle& i) {
|
||||
Tensor relu = Compute("relu", {4}, [&](const VarHandle& i) {
|
||||
return Max::make(a.load(i), ExprHandle(alloc<FloatImm>(0)), true);
|
||||
});
|
||||
|
||||
|
|
@ -1036,8 +1033,8 @@ TEST(Cuda, UnusedHalfArgument_CUDA) {
|
|||
const std::string& verification_pattern =
|
||||
R"IR(
|
||||
# CHECK: for (
|
||||
# CHECK: float v = a[n];
|
||||
# CHECK: relu[n] = Max(v, 0.f
|
||||
# CHECK: float v = a[i];
|
||||
# CHECK: relu[i] = Max(v, 0.f
|
||||
# CHECK: })IR";
|
||||
|
||||
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
|
||||
|
|
@ -1150,10 +1147,9 @@ TEST(Cuda, MaskBlockDim_CUDA) {
|
|||
int B_SIZE = 50;
|
||||
BufHandle a_buf("a", {A_SIZE}, kFloat);
|
||||
BufHandle b_buf("b", {B_SIZE}, kFloat);
|
||||
Tensor c = Compute("c", {{A_SIZE, "i"}}, [&](const VarHandle& i) {
|
||||
return a_buf.load(i) + 10;
|
||||
});
|
||||
Tensor d = Compute("d", {{B_SIZE, "i"}}, [&](const VarHandle& i) {
|
||||
Tensor c = Compute(
|
||||
"c", {A_SIZE}, [&](const VarHandle& i) { return a_buf.load(i) + 10; });
|
||||
Tensor d = Compute("d", {B_SIZE}, [&](const VarHandle& i) {
|
||||
return a_buf.load(i) + b_buf.load(i);
|
||||
});
|
||||
|
||||
|
|
@ -1242,10 +1238,9 @@ TEST(Cuda, MaskThreadDim_CUDA) {
|
|||
int B_SIZE = 100;
|
||||
BufHandle a_buf("a", {A_SIZE}, kFloat);
|
||||
BufHandle b_buf("b", {B_SIZE}, kFloat);
|
||||
Tensor c = Compute("c", {{A_SIZE, "i"}}, [&](const VarHandle& i) {
|
||||
return a_buf.load(i) + 10;
|
||||
});
|
||||
Tensor d = Compute("d", {{B_SIZE, "i"}}, [&](const VarHandle& i) {
|
||||
Tensor c = Compute(
|
||||
"c", {A_SIZE}, [&](const VarHandle& i) { return a_buf.load(i) + 10; });
|
||||
Tensor d = Compute("d", {B_SIZE}, [&](const VarHandle& i) {
|
||||
return a_buf.load(i / 2) + b_buf.load(i);
|
||||
});
|
||||
|
||||
|
|
@ -1336,10 +1331,9 @@ TEST(Cuda, MaskMultiBlockDim_CUDA) {
|
|||
int B_SIZE = 50;
|
||||
BufHandle a_buf("a", {A_SIZE}, kFloat);
|
||||
BufHandle b_buf("b", {B_SIZE}, kFloat);
|
||||
Tensor c = Compute("c", {{A_SIZE, "i"}}, [&](const VarHandle& i) {
|
||||
return a_buf.load(i) + 10;
|
||||
});
|
||||
Tensor d = Compute("d", {{B_SIZE, "i"}}, [&](const VarHandle& i) {
|
||||
Tensor c = Compute(
|
||||
"c", {A_SIZE}, [&](const VarHandle& i) { return a_buf.load(i) + 10; });
|
||||
Tensor d = Compute("d", {B_SIZE}, [&](const VarHandle& i) {
|
||||
return a_buf.load(i) + b_buf.load(i);
|
||||
});
|
||||
|
||||
|
|
@ -1429,10 +1423,9 @@ TEST(Cuda, MaskBlockAndThreadDim_CUDA) {
|
|||
int B_SIZE = 50;
|
||||
BufHandle a_buf("a", {A_SIZE}, kFloat);
|
||||
BufHandle b_buf("b", {B_SIZE}, kFloat);
|
||||
Tensor c = Compute("c", {{A_SIZE, "i"}}, [&](const VarHandle& i) {
|
||||
return a_buf.load(i) + 10;
|
||||
});
|
||||
Tensor d = Compute("d", {{B_SIZE, "i"}}, [&](const VarHandle& i) {
|
||||
Tensor c = Compute(
|
||||
"c", {A_SIZE}, [&](const VarHandle& i) { return a_buf.load(i) + 10; });
|
||||
Tensor d = Compute("d", {B_SIZE}, [&](const VarHandle& i) {
|
||||
return a_buf.load(i) + b_buf.load(i);
|
||||
});
|
||||
|
||||
|
|
@ -1522,15 +1515,11 @@ TEST(Cuda, MaskMultiDim_CUDA) {
|
|||
BufHandle a_buf("a", {OUTER_SIZE, A_SIZE}, kFloat);
|
||||
BufHandle b_buf("b", {OUTER_SIZE, B_SIZE}, kFloat);
|
||||
Tensor c = Compute(
|
||||
"C",
|
||||
{{OUTER_SIZE, "i"}, {A_SIZE, "j"}},
|
||||
[&](const VarHandle& i, const VarHandle& j) {
|
||||
"C", {OUTER_SIZE, A_SIZE}, [&](const VarHandle& i, const VarHandle& j) {
|
||||
return ExprHandle(2) * a_buf.load(i, j);
|
||||
});
|
||||
Tensor d = Compute(
|
||||
"D",
|
||||
{{OUTER_SIZE, "i"}, {B_SIZE, "j"}},
|
||||
[&](const VarHandle& i, const VarHandle& j) {
|
||||
"D", {OUTER_SIZE, B_SIZE}, [&](const VarHandle& i, const VarHandle& j) {
|
||||
return c.load(i, j * 2) + b_buf.load(i, j);
|
||||
});
|
||||
|
||||
|
|
@ -1651,15 +1640,11 @@ TEST(Cuda, MaskMultiDimSymbolic_CUDA) {
|
|||
BufHandle a_buf("a", {OUTER_SIZE, A_SIZE}, kFloat);
|
||||
BufHandle b_buf("b", {OUTER_SIZE, B_SIZE}, kFloat);
|
||||
Tensor c = Compute(
|
||||
"C",
|
||||
{{OUTER_SIZE, "i"}, {A_SIZE, "j"}},
|
||||
[&](const VarHandle& i, const VarHandle& j) {
|
||||
"C", {OUTER_SIZE, A_SIZE}, [&](const VarHandle& i, const VarHandle& j) {
|
||||
return ExprHandle(2) * a_buf.load(i, j);
|
||||
});
|
||||
Tensor d = Compute(
|
||||
"D",
|
||||
{{OUTER_SIZE, "i"}, {B_SIZE, "j"}},
|
||||
[&](const VarHandle& i, const VarHandle& j) {
|
||||
"D", {OUTER_SIZE, B_SIZE}, [&](const VarHandle& i, const VarHandle& j) {
|
||||
return c.load(i, j * 2) + b_buf.load(i, j);
|
||||
});
|
||||
|
||||
|
|
@ -2062,15 +2047,11 @@ TEST(Cuda, MaskMultiDimMultiAxis_CUDA) {
|
|||
BufHandle a_buf("a", {OUTER_SIZE, A_SIZE}, kFloat);
|
||||
BufHandle b_buf("b", {OUTER_SIZE, B_SIZE}, kFloat);
|
||||
Tensor c = Compute(
|
||||
"C",
|
||||
{{OUTER_SIZE, "i"}, {A_SIZE, "j"}},
|
||||
[&](const VarHandle& i, const VarHandle& j) {
|
||||
"C", {OUTER_SIZE, A_SIZE}, [&](const VarHandle& i, const VarHandle& j) {
|
||||
return ExprHandle(2) * a_buf.load(i, j);
|
||||
});
|
||||
Tensor d = Compute(
|
||||
"D",
|
||||
{{OUTER_SIZE, "i"}, {B_SIZE, "j"}},
|
||||
[&](const VarHandle& i, const VarHandle& j) {
|
||||
"D", {OUTER_SIZE, B_SIZE}, [&](const VarHandle& i, const VarHandle& j) {
|
||||
return c.load(i, j * 2) + b_buf.load(i, j);
|
||||
});
|
||||
|
||||
|
|
@ -2192,15 +2173,11 @@ TEST(Cuda, MaskMultiDimMultiLevel_CUDA) {
|
|||
BufHandle a_buf("a", {OUTER_A_SIZE, A_SIZE}, kFloat);
|
||||
BufHandle b_buf("b", {OUTER_B_SIZE, B_SIZE}, kFloat);
|
||||
Tensor c = Compute(
|
||||
"C",
|
||||
{{OUTER_A_SIZE, "i"}, {A_SIZE, "j"}},
|
||||
[&](const VarHandle& i, const VarHandle& j) {
|
||||
"C", {OUTER_A_SIZE, A_SIZE}, [&](const VarHandle& i, const VarHandle& j) {
|
||||
return ExprHandle(2) * a_buf.load(i, j);
|
||||
});
|
||||
Tensor d = Compute(
|
||||
"D",
|
||||
{{OUTER_B_SIZE, "i"}, {B_SIZE, "j"}},
|
||||
[&](const VarHandle& i, const VarHandle& j) {
|
||||
"D", {OUTER_B_SIZE, B_SIZE}, [&](const VarHandle& i, const VarHandle& j) {
|
||||
return c.load(i, j * 2) + b_buf.load(i, j);
|
||||
});
|
||||
|
||||
|
|
|
|||
|
|
@ -777,14 +777,14 @@ TEST(ExternalCall, ComputeInterop) {
|
|||
|
||||
Tensor Input = Compute(
|
||||
"Input",
|
||||
{{1, "n"}, {16, "c"}, {32, "h"}, {32, "w"}},
|
||||
{1, 16, 32, 32},
|
||||
[&](const VarHandle& n,
|
||||
const VarHandle& c,
|
||||
const VarHandle& h,
|
||||
const VarHandle& w) { return FloatImm::make(5.0f); });
|
||||
Tensor Weight = Compute(
|
||||
"Weight",
|
||||
{{16, "n"}, {16, "c"}, {1, "kh"}, {1, "kw"}},
|
||||
{16, 16, 1, 1},
|
||||
[&](const VarHandle& n,
|
||||
const VarHandle& c,
|
||||
const VarHandle& h,
|
||||
|
|
@ -806,7 +806,7 @@ TEST(ExternalCall, ComputeInterop) {
|
|||
{}));
|
||||
Tensor Result = Compute(
|
||||
"Result",
|
||||
{{1, "n"}, {16, "c"}, {32, "h"}, {32, "w"}},
|
||||
{1, 16, 32, 32},
|
||||
[&](const VarHandle& n,
|
||||
const VarHandle& c,
|
||||
const VarHandle& h,
|
||||
|
|
@ -866,14 +866,12 @@ TEST(ExternalCall, Inlining) {
|
|||
|
||||
BufHandle MatmulResultBuf("MatmulResult", {8, 8}, kFloat);
|
||||
|
||||
Tensor A = Compute(
|
||||
"A", {{8, "i"}, {8, "j"}}, [&](const VarHandle& i, const VarHandle& j) {
|
||||
return FloatImm::make(5.0f);
|
||||
});
|
||||
Tensor B = Compute(
|
||||
"B", {{8, "i"}, {8, "j"}}, [&](const VarHandle& i, const VarHandle& j) {
|
||||
return FloatImm::make(4.0f);
|
||||
});
|
||||
Tensor A = Compute("A", {8, 8}, [&](const VarHandle& i, const VarHandle& j) {
|
||||
return FloatImm::make(5.0f);
|
||||
});
|
||||
Tensor B = Compute("B", {8, 8}, [&](const VarHandle& i, const VarHandle& j) {
|
||||
return FloatImm::make(4.0f);
|
||||
});
|
||||
Tensor MatmulResult = Tensor(
|
||||
MatmulResultBuf.node(),
|
||||
ExternalCall::make(
|
||||
|
|
@ -881,10 +879,8 @@ TEST(ExternalCall, Inlining) {
|
|||
"nnc_aten_matmul",
|
||||
{BufHandle(A.buf()), BufHandle(B.buf())},
|
||||
{}));
|
||||
Tensor Result = Compute(
|
||||
"Result",
|
||||
{{8, "i"}, {8, "j"}},
|
||||
[&](const VarHandle& i, const VarHandle& j) {
|
||||
Tensor Result =
|
||||
Compute("Result", {8, 8}, [&](const VarHandle& i, const VarHandle& j) {
|
||||
return MatmulResult.load(i, j) + FloatImm::make(3.0f);
|
||||
});
|
||||
|
||||
|
|
|
|||
|
|
@ -53,42 +53,36 @@ TEST(IRPrinter, FunctionName) {
|
|||
int N = 20;
|
||||
|
||||
Tensor producer = Compute(
|
||||
"producer",
|
||||
{{M, "m"}, {N, "n"}},
|
||||
[&](const ExprHandle& m, const ExprHandle& n) { return m * n; });
|
||||
"producer", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) {
|
||||
return m * n;
|
||||
});
|
||||
|
||||
Tensor chunk_0 = Compute(
|
||||
"chunk",
|
||||
{{M, "m"}, {N / 2, "n"}},
|
||||
[&](const ExprHandle& m, const ExprHandle& n) {
|
||||
"chunk_0", {M, N / 2}, [&](const ExprHandle& m, const ExprHandle& n) {
|
||||
return producer.load(m, n);
|
||||
});
|
||||
|
||||
Tensor chunk_1 = Compute(
|
||||
"chunk",
|
||||
{{M, "m"}, {N / 2, "n"}},
|
||||
[&](const ExprHandle& m, const ExprHandle& n) {
|
||||
"chunk_1", {M, N / 2}, [&](const ExprHandle& m, const ExprHandle& n) {
|
||||
return producer.load(m, n + ExprHandle(N / 2));
|
||||
});
|
||||
|
||||
Tensor consumer = Compute(
|
||||
"consumer",
|
||||
{{M, "i"}, {N / 2, "j"}},
|
||||
[&](const ExprHandle& i, const ExprHandle& j) {
|
||||
"consumer", {M, N / 2}, [&](const ExprHandle& i, const ExprHandle& j) {
|
||||
return i * chunk_1.load(i, j);
|
||||
});
|
||||
|
||||
LoopNest l({chunk_0, chunk_1, consumer});
|
||||
auto body = l.root_stmt();
|
||||
auto body = LoopNest::sanitizeNames(l.root_stmt());
|
||||
|
||||
std::stringstream ss;
|
||||
ss << *body;
|
||||
|
||||
const std::string& verification_pattern =
|
||||
R"IR(
|
||||
# CHECK: for (int i
|
||||
# CHECK: for (int j
|
||||
# CHECK: consumer[i, j] = i * (chunk_1[i, j])IR";
|
||||
# CHECK: for (int i_2
|
||||
# CHECK: for (int j_2
|
||||
# CHECK: consumer[i_2, j_2] = i_2 * (chunk_1[i_2, j_2])IR";
|
||||
|
||||
torch::jit::testing::FileCheck().run(verification_pattern, ss.str());
|
||||
}
|
||||
|
|
|
|||
|
|
@@ -1602,7 +1602,7 @@ Tensor lowerNanToNum(
   auto input_buf = c10::get<BufHandle>(inputs[0]);
   auto e = Compute(
       "custom_nan_to_num",
-      c10::fmap<DimArg>(outputShape),
+      outputShape,
       [&](const std::vector<VarHandle>& axes) {
         std::vector<ExprHandle> indices(axes.begin(), axes.end());
         auto load = input_buf.load(indices);
@ -584,8 +584,7 @@ DOUBLE_INTRINSICS_TEST(lgamma, 4)
|
|||
TEST(LLVM, VectorizerLoadStoreTest) {
|
||||
BufHandle a("A", {1}, kInt);
|
||||
|
||||
Tensor c =
|
||||
Compute("c", {{4, "i"}}, [&](const VarHandle& i) { return a.load(i); });
|
||||
Tensor c = Compute("c", {4}, [&](const VarHandle& i) { return a.load(i); });
|
||||
|
||||
BufHandle c_buf(c.buf());
|
||||
LoopNest l({c});
|
||||
|
|
@ -606,7 +605,7 @@ TEST(LLVM, VectorizerLoadStoreTest) {
|
|||
TEST(LLVM, VectorizeBitCast) {
|
||||
BufHandle a("A", {128}, kInt);
|
||||
|
||||
Tensor c = Compute("c", {{128, "i"}}, [&](const VarHandle& i) {
|
||||
Tensor c = Compute("c", {128}, [&](const VarHandle& i) {
|
||||
return bitcast<float>(a.load(i));
|
||||
});
|
||||
|
||||
|
|
@ -1186,9 +1185,8 @@ TEST(LLVM, StoreFloat) {
|
|||
|
||||
TEST(LLVM, SimpleMath01) {
|
||||
const int N = 1024;
|
||||
Tensor tensor = Compute("f", {{N, "i"}}, [](const VarHandle& i) {
|
||||
return cast<float>(i * i + 1);
|
||||
});
|
||||
Tensor tensor = Compute(
|
||||
"f", {N}, [](const VarHandle& i) { return cast<float>(i * i + 1); });
|
||||
LoopNest l({tensor});
|
||||
StmtPtr stmt = l.root_stmt();
|
||||
BufHandle f_buf(tensor.buf());
|
||||
|
|
@ -1209,9 +1207,8 @@ TEST(LLVM, ComputeMul) {
|
|||
const int N = 1024;
|
||||
BufHandle a("a", {N}, kFloat);
|
||||
BufHandle b("b", {N}, kFloat);
|
||||
Tensor c = Compute("c", {{N, "i"}}, [&](const VarHandle& i) {
|
||||
return a.load(i) * b.load(i);
|
||||
});
|
||||
Tensor c = Compute(
|
||||
"c", {N}, [&](const VarHandle& i) { return a.load(i) * b.load(i); });
|
||||
|
||||
BufHandle c_buf(c.buf());
|
||||
LoopNest l({c});
|
||||
|
|
@ -1232,10 +1229,9 @@ TEST(LLVM, BroadcastAdd) {
|
|||
const int N = 1024;
|
||||
BufHandle a("a", {M, N}, kFloat);
|
||||
BufHandle b("b", {N}, kFloat);
|
||||
Tensor c = Compute(
|
||||
"c", {{M, "i"}, {N, "j"}}, [&](const VarHandle& i, const VarHandle& j) {
|
||||
return a.load(i, j) + b.load(j);
|
||||
});
|
||||
Tensor c = Compute("c", {M, N}, [&](const VarHandle& i, const VarHandle& j) {
|
||||
return a.load(i, j) + b.load(j);
|
||||
});
|
||||
|
||||
BufHandle c_buf(c.buf());
|
||||
LoopNest l({c});
|
||||
|
|
@ -1333,9 +1329,8 @@ TEST(LLVM, TensorDynamicShapeAdd) {
|
|||
VarHandle n("n", kInt);
|
||||
BufHandle a("a", {n}, kFloat);
|
||||
BufHandle b("b", {n}, kFloat);
|
||||
Tensor c = Compute("c", {{n, "n"}}, [&](const VarHandle& i) {
|
||||
return a.load(i) + b.load(i);
|
||||
});
|
||||
Tensor c = Compute(
|
||||
"c", {n}, [&](const VarHandle& i) { return a.load(i) + b.load(i); });
|
||||
LoopNest l({c});
|
||||
StmtPtr s = l.root_stmt();
|
||||
LLVMCodeGen cg(s, {a, b, c, n});
|
||||
|
|
@ -1356,8 +1351,8 @@ TEST(LLVM, DynamicShape2D) {
|
|||
VarHandle n("n", kInt);
|
||||
BufHandle a("a", {m, n}, kFloat);
|
||||
BufHandle b("b", {m, n}, kFloat);
|
||||
Tensor c = Compute(
|
||||
"c", {{m, "m"}, {n, "n"}}, [&](const VarHandle& i, const VarHandle& j) {
|
||||
Tensor c =
|
||||
Compute("c", {m, n}, [&](const VarHandle& i, const VarHandle& j) {
|
||||
return a.load(i, j) + b.load(i, j);
|
||||
});
|
||||
LoopNest l({c});
|
||||
|
|
@ -1386,7 +1381,7 @@ TEST(LLVM, EmptyStmt) {
|
|||
TEST(LLVM, EliminatedStmt) {
|
||||
BufHandle a("a", {1}, kFloat);
|
||||
|
||||
Tensor c = Compute("c", {{0, "m"}}, [&](const VarHandle& m) { return m; });
|
||||
Tensor c = Compute("c", {0}, [&](const VarHandle& m) { return m; });
|
||||
|
||||
LoopNest l({c});
|
||||
l.prepareForCodegen();
|
||||
|
|
@@ -1405,10 +1400,7 @@ TEST(LLVM, SimpleReduction) {

   BufHandle a("a", {1, M, N}, kFloat);

-  // TODO: why doesn't implicit vector<DimArg> work?
-  std::vector<DimArg> axis = {DimArg(1)};
-  std::vector<DimArg> reduce_axis = {DimArg(M), DimArg(N)};
-  Tensor b = Reduce("sum", axis, Sum(), a, reduce_axis);
+  Tensor b = Reduce("sum", {1}, Sum(), a, {M, N});
   LoopNest loop({b});

   loop.prepareForCodegen();
@ -1442,10 +1434,7 @@ TEST(LLVM, RFactorReduction) {
|
|||
|
||||
BufHandle a("a", {1, M, N}, kFloat);
|
||||
|
||||
// TODO: why doesn't implicit vector<DimArg> work?
|
||||
std::vector<DimArg> axis = {DimArg(1)};
|
||||
std::vector<DimArg> reduce_axis = {DimArg(M), DimArg(N)};
|
||||
Tensor b = Reduce("sum", axis, Sum(), a, reduce_axis);
|
||||
Tensor b = Reduce("sum", {1}, Sum(), a, {M, N});
|
||||
LoopNest loop({b});
|
||||
|
||||
std::vector<ForPtr> loops = loop.getLoopStmtsFor(b);
|
||||
|
|
@ -1490,7 +1479,7 @@ TEST(LLVM, RFactorVectorizedReduction) {
|
|||
|
||||
BufHandle a("a", {1, M, N}, kFloat);
|
||||
|
||||
Tensor b = Reduce("sum", {{1, "K"}}, Sum(), a, {{M, "M"}, {N, "N"}});
|
||||
Tensor b = Reduce("sum", {1}, Sum(), a, {M, N});
|
||||
LoopNest loopnest({b});
|
||||
std::vector<ForPtr> loops = loopnest.getLoopStmtsFor(b);
|
||||
// Reorder n and m loops
|
||||
|
|
@ -1536,10 +1525,9 @@ static void testSimpleParallel() {
|
|||
// parallel or sequential.
|
||||
const int M = 4;
|
||||
const int N = 6;
|
||||
Tensor f = Compute(
|
||||
"f", {{M, "m"}, {N, "n"}}, [](const VarHandle& m, const VarHandle& n) {
|
||||
return cast<float>(m + n);
|
||||
});
|
||||
Tensor f = Compute("f", {M, N}, [](const VarHandle& m, const VarHandle& n) {
|
||||
return cast<float>(m + n);
|
||||
});
|
||||
LoopNest loop_nest({f});
|
||||
auto const& loops = loop_nest.getLoopStmtsFor(f);
|
||||
ForPtr m = loops[0];
|
||||
|
|
@ -1588,20 +1576,14 @@ TEST(LLVM, CompositeParallel) {
|
|||
for (const auto test_cfg : c10::irange(test_count)) {
|
||||
int M = 5;
|
||||
int N = 7;
|
||||
Tensor t1 =
|
||||
Compute("t1", {{M, "M"}}, [](const VarHandle& m) { return m + 1.f; });
|
||||
Tensor t2 =
|
||||
Compute("t2", {{N, "N"}}, [](const VarHandle& n) { return n + 2.f; });
|
||||
Tensor t3 = Compute(
|
||||
"t3",
|
||||
{{M, "M"}, {N, "N"}},
|
||||
[=](const VarHandle& m, const VarHandle& n) {
|
||||
Tensor t1 = Compute("t1", {M}, [](const VarHandle& m) { return m + 1.f; });
|
||||
Tensor t2 = Compute("t2", {N}, [](const VarHandle& n) { return n + 2.f; });
|
||||
Tensor t3 =
|
||||
Compute("t3", {M, N}, [=](const VarHandle& m, const VarHandle& n) {
|
||||
return t1.load(m) * t2.load(n);
|
||||
});
|
||||
Tensor t4 = Compute(
|
||||
"t4",
|
||||
{{M, "M"}, {N, "N"}},
|
||||
[=](const VarHandle& m, const VarHandle& n) {
|
||||
Tensor t4 =
|
||||
Compute("t4", {M, N}, [=](const VarHandle& m, const VarHandle& n) {
|
||||
return t3.load(m, n) + m + n;
|
||||
});
|
||||
LoopNest loop_nest({t4}, {t1, t2, t3, t4});
|
||||
|
|
@ -1657,12 +1639,12 @@ TEST(LLVM, VectorizedGEMM) {
|
|||
BufHandle BP("B", {K, N}, kFloat);
|
||||
Tensor CT = Reduce(
|
||||
"gemm",
|
||||
{{M, "M"}, {N, "N"}},
|
||||
{M, N},
|
||||
Sum(),
|
||||
[&](const ExprHandle& m, const ExprHandle& n, const ExprHandle& k) {
|
||||
return AP.load(m, k) * BP.load(k, n);
|
||||
},
|
||||
{{K, "K"}});
|
||||
{K});
|
||||
LoopNest loop({CT});
|
||||
|
||||
{
|
||||
|
|
@ -1735,10 +1717,9 @@ TEST(LLVM, CallRaw) {
|
|||
VarHandle N("N", kInt);
|
||||
BufHandle a("a", {M, N}, kFloat);
|
||||
BufHandle b("b", {N}, kFloat);
|
||||
Tensor c = Compute(
|
||||
"c", {{M, "i"}, {N, "j"}}, [&](const VarHandle& i, const VarHandle& j) {
|
||||
return a.load(i, j) + b.load(j);
|
||||
});
|
||||
Tensor c = Compute("c", {M, N}, [&](const VarHandle& i, const VarHandle& j) {
|
||||
return a.load(i, j) + b.load(j);
|
||||
});
|
||||
|
||||
LoopNest l({c});
|
||||
l.prepareForCodegen();
|
||||
|
|
@ -1776,7 +1757,7 @@ TEST(LLVM, CustomTarget) {
|
|||
BufHandle a("a", {M}, kFloat);
|
||||
BufHandle b("b", {M}, kFloat);
|
||||
BufHandle c("c", {M}, kFloat);
|
||||
Tensor d = Compute("d", {{M, "m"}}, [&](const VarHandle& m) {
|
||||
Tensor d = Compute("d", {M}, [&](const VarHandle& m) {
|
||||
return a.load(m) * b.load(m) + c.load(m);
|
||||
});
|
||||
LoopNest nest({d});
|
||||
|
|
|
|||
File diff suppressed because it is too large
|
|
@ -2696,13 +2696,13 @@ TEST(MemDependency, MemDependencyCheckerComputeAPI) {
|
|||
BufHandle b_buf("b", {5, 6}, kFloat);
|
||||
Tensor c = Compute(
|
||||
"broadcast_add",
|
||||
{{4, "m"}, {5, "n"}, {6, "k"}},
|
||||
{4, 5, 6},
|
||||
[&](const VarHandle& m, const VarHandle& n, const VarHandle& k) {
|
||||
return a_buf.load(m, n) + b_buf.load(n, k);
|
||||
});
|
||||
Tensor d = Compute(
|
||||
"d",
|
||||
{{4, "m"}, {5, "n"}, {6, "k"}},
|
||||
{4, 5, 6},
|
||||
[&](const VarHandle& m, const VarHandle& n, const VarHandle& k) {
|
||||
return c.load(m, n, k) + 1;
|
||||
});
|
||||
|
|
@ -2741,13 +2741,13 @@ TEST(MemDependency, MemDependencyCheckerComputeInline) {
|
|||
BufHandle b_buf("b", {5, 6}, kFloat);
|
||||
Tensor c = Compute(
|
||||
"broadcast_add",
|
||||
{{4, "m"}, {5, "n"}, {6, "k"}},
|
||||
{4, 5, 6},
|
||||
[&](const VarHandle& m, const VarHandle& n, const VarHandle& k) {
|
||||
return a_buf.load(m, n) + b_buf.load(n, k);
|
||||
});
|
||||
Tensor d = Compute(
|
||||
"d",
|
||||
{{4, "m"}, {5, "n"}, {6, "k"}},
|
||||
{4, 5, 6},
|
||||
[&](const VarHandle& m, const VarHandle& n, const VarHandle& k) {
|
||||
return c.load(m, n, k) + 1;
|
||||
});
|
||||
|
|
@ -2776,7 +2776,7 @@ TEST(MemDependency, MemDependencyCheckerComputeSplit) {
|
|||
BufHandle b_buf("b", {5, 6}, kFloat);
|
||||
Tensor c = Compute(
|
||||
"broadcast_add",
|
||||
{{4, "m"}, {5, "n"}, {6, "k"}},
|
||||
{4, 5, 6},
|
||||
[&](const VarHandle& m, const VarHandle& n, const VarHandle& k) {
|
||||
return a_buf.load(m, n) + b_buf.load(n, k);
|
||||
});
|
||||
|
|
@ -2822,7 +2822,7 @@ TEST(MemDependency, MemDependencyCheckerComputeReorder) {
|
|||
BufHandle b_buf("b", {5, 6}, kFloat);
|
||||
Tensor c = Compute(
|
||||
"broadcast_add",
|
||||
{{4, "m"}, {5, "n"}, {6, "k"}},
|
||||
{4, 5, 6},
|
||||
[&](const VarHandle& m, const VarHandle& n, const VarHandle& k) {
|
||||
return a_buf.load(m, n) + b_buf.load(n, k);
|
||||
});
|
||||
|
|
@ -2888,11 +2888,11 @@ TEST(MemDependency, MemDependencyCheckerComputeReduce) {
|
|||
|
||||
Tensor c = Compute(
|
||||
"scale",
|
||||
{{2, "l2"}, {3, "n1"}, {6, "m1"}},
|
||||
{2, 3, 6},
|
||||
[&](const VarHandle& l, const VarHandle& n, const VarHandle& m) {
|
||||
return b.load(l, n, m) * a.load(l, n, m);
|
||||
});
|
||||
Tensor d = Reduce("sum", {{2, "l1"}}, Sum(), c, {{3, "n1"}, {6, "m1"}});
|
||||
Tensor d = Reduce("sum", {2}, Sum(), c, {3, 6});
|
||||
LoopNest l({d}, {c, d});
|
||||
|
||||
MemDependencyChecker analyzer({a.node(), b.node()}, {d.buf()});
|
||||
|
|
@ -2924,12 +2924,12 @@ TEST(MemDependency, MemDependencyCheckerComputeGEMM) {
|
|||
BufHandle BP("B", {K, N}, kFloat);
|
||||
Tensor CT = Reduce(
|
||||
"gemm",
|
||||
{{M, "M"}, {N, "N"}},
|
||||
{M, N},
|
||||
Sum(),
|
||||
[&](const ExprHandle& m, const ExprHandle& n, const ExprHandle& k) {
|
||||
return AP.load(m, k) * BP.load(k, n);
|
||||
},
|
||||
{{K, "K"}});
|
||||
{K});
|
||||
LoopNest loop({CT});
|
||||
|
||||
{
|
||||
|
|
|
|||
|
|
@ -95,30 +95,24 @@ TEST(MemPlanning, SameBufSizeMemReuse) {
|
|||
|
||||
Tensor CT = Reduce(
|
||||
"gemm",
|
||||
{{M, "M"}, {N, "N"}},
|
||||
{M, N},
|
||||
Sum(),
|
||||
[&](const ExprHandle& m, const ExprHandle& n, const ExprHandle& k) {
|
||||
return AP.load(m, k) * BP.load(k, n);
|
||||
},
|
||||
{{K, "K"}});
|
||||
Tensor DT = Compute(
|
||||
"relu",
|
||||
{{M, "M"}, {N, "N"}},
|
||||
[&](const ExprHandle& m, const ExprHandle& n) {
|
||||
{K});
|
||||
Tensor DT =
|
||||
Compute("relu", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) {
|
||||
auto zero = Cast::make(CT.buf()->dtype(), 0);
|
||||
return CompareSelect::make(
|
||||
CT.load(m, n), zero, zero, CT.load(m, n), kLT);
|
||||
});
|
||||
Tensor ET = Compute(
|
||||
"add",
|
||||
{{M, "M"}, {N, "N"}},
|
||||
[&](const ExprHandle& m, const ExprHandle& n) {
|
||||
Tensor ET =
|
||||
Compute("add", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) {
|
||||
return DT.load(m, n) + DT.load(m, n);
|
||||
});
|
||||
Tensor FT = Compute(
|
||||
"mul",
|
||||
{{M, "M"}, {N, "N"}},
|
||||
[&](const ExprHandle& m, const ExprHandle& n) {
|
||||
Tensor FT =
|
||||
Compute("mul", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) {
|
||||
return ET.load(m, n) * ET.load(m, n);
|
||||
});
|
||||
auto stmt = Block::make({CT.stmt(), DT.stmt(), ET.stmt(), FT.stmt()});
|
||||
|
|
@ -188,36 +182,28 @@ TEST(MemPlanning, SameBufSizeMultiMemReuses) {
|
|||
|
||||
Tensor CT = Reduce(
|
||||
"gemm",
|
||||
{{M, "M"}, {N, "N"}},
|
||||
{M, N},
|
||||
Sum(),
|
||||
[&](const ExprHandle& m, const ExprHandle& n, const ExprHandle& k) {
|
||||
return AP.load(m, k) * BP.load(k, n);
|
||||
},
|
||||
{{K, "K"}});
|
||||
Tensor DT = Compute(
|
||||
"relu",
|
||||
{{M, "M"}, {N, "N"}},
|
||||
[&](const ExprHandle& m, const ExprHandle& n) {
|
||||
{K});
|
||||
Tensor DT =
|
||||
Compute("relu", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) {
|
||||
auto zero = Cast::make(CT.buf()->dtype(), 0);
|
||||
return CompareSelect::make(
|
||||
CT.load(m, n), zero, zero, CT.load(m, n), kLT);
|
||||
});
|
||||
Tensor ET = Compute(
|
||||
"add",
|
||||
{{M, "M"}, {N, "N"}},
|
||||
[&](const ExprHandle& m, const ExprHandle& n) {
|
||||
Tensor ET =
|
||||
Compute("add", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) {
|
||||
return DT.load(m, n) + DT.load(m, n);
|
||||
});
|
||||
Tensor FT = Compute(
|
||||
"mul",
|
||||
{{M, "M"}, {N, "N"}},
|
||||
[&](const ExprHandle& m, const ExprHandle& n) {
|
||||
Tensor FT =
|
||||
Compute("mul", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) {
|
||||
return ET.load(m, n) * ET.load(m, n);
|
||||
});
|
||||
Tensor GT = Compute(
|
||||
"sub",
|
||||
{{M, "M"}, {N, "N"}},
|
||||
[&](const ExprHandle& m, const ExprHandle& n) {
|
||||
Tensor GT =
|
||||
Compute("sub", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) {
|
||||
return FT.load(m, n) - ET.load(m, n);
|
||||
});
|
||||
|
||||
|
|
@ -296,42 +282,32 @@ TEST(MemPlanning, SameBufSizeMultiMemReusesOfOneBuf) {
|
|||
|
||||
Tensor CT = Reduce(
|
||||
"gemm",
|
||||
{{M, "M"}, {N, "N"}},
|
||||
{M, N},
|
||||
Sum(),
|
||||
[&](const ExprHandle& m, const ExprHandle& n, const ExprHandle& k) {
|
||||
return AP.load(m, k) * BP.load(k, n);
|
||||
},
|
||||
{{K, "K"}});
|
||||
Tensor DT = Compute(
|
||||
"relu",
|
||||
{{M, "M"}, {N, "N"}},
|
||||
[&](const ExprHandle& m, const ExprHandle& n) {
|
||||
{K});
|
||||
Tensor DT =
|
||||
Compute("relu", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) {
|
||||
auto zero = Cast::make(CT.buf()->dtype(), 0);
|
||||
return CompareSelect::make(
|
||||
CT.load(m, n), zero, zero, CT.load(m, n), kLT);
|
||||
});
|
||||
Tensor ET = Compute(
|
||||
"add",
|
||||
{{M, "M"}, {N, "N"}},
|
||||
[&](const ExprHandle& m, const ExprHandle& n) {
|
||||
Tensor ET =
|
||||
Compute("add", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) {
|
||||
return DT.load(m, n) + DT.load(m, n);
|
||||
});
|
||||
Tensor FT = Compute(
|
||||
"mul",
|
||||
{{M, "M"}, {N, "N"}},
|
||||
[&](const ExprHandle& m, const ExprHandle& n) {
|
||||
Tensor FT =
|
||||
Compute("mul", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) {
|
||||
return ET.load(m, n) * ET.load(m, n);
|
||||
});
|
||||
Tensor GT = Compute(
|
||||
"sub",
|
||||
{{M, "M"}, {N, "N"}},
|
||||
[&](const ExprHandle& m, const ExprHandle& n) {
|
||||
Tensor GT =
|
||||
Compute("sub", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) {
|
||||
return FT.load(m, n) - 1;
|
||||
});
|
||||
Tensor HT = Compute(
|
||||
"div",
|
||||
{{M, "M"}, {N, "N"}},
|
||||
[&](const ExprHandle& m, const ExprHandle& n) {
|
||||
Tensor HT =
|
||||
Compute("div", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) {
|
||||
return GT.load(m, n) / 2;
|
||||
});
|
||||
|
||||
|
|
@ -418,30 +394,24 @@ TEST(MemPlanning, SmallerBufSizeNonMemReuse) {
|
|||
|
||||
Tensor CT = Reduce(
|
||||
"gemm",
|
||||
{{M, "M"}, {N, "N"}},
|
||||
{M, N},
|
||||
Sum(),
|
||||
[&](const ExprHandle& m, const ExprHandle& n, const ExprHandle& k) {
|
||||
return AP.load(m, k) * BP.load(k, n);
|
||||
},
|
||||
{{K, "K"}});
|
||||
Tensor DT = Compute(
|
||||
"relu",
|
||||
{{M, "M"}, {N, "N"}},
|
||||
[&](const ExprHandle& m, const ExprHandle& n) {
|
||||
{K});
|
||||
Tensor DT =
|
||||
Compute("relu", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) {
|
||||
auto zero = Cast::make(CT.buf()->dtype(), 0);
|
||||
return CompareSelect::make(
|
||||
CT.load(m, n), zero, zero, CT.load(m, n), kLT);
|
||||
});
|
||||
Tensor ET = Compute(
|
||||
"add",
|
||||
{{M * 2, "EM"}, {N * 2, "EN"}},
|
||||
[&](const ExprHandle& em, const ExprHandle& en) {
|
||||
"add", {M * 2, N * 2}, [&](const ExprHandle& em, const ExprHandle& en) {
|
||||
return DT.load(em / 2, en / 2) + DT.load(em / 2, en / 2);
|
||||
});
|
||||
Tensor FT = Compute(
|
||||
"mul",
|
||||
{{M * 2, "FM"}, {N * 2, "FN"}},
|
||||
[&](const ExprHandle& fm, const ExprHandle& fn) {
|
||||
"mul", {M * 2, N * 2}, [&](const ExprHandle& fm, const ExprHandle& fn) {
|
||||
return ET.load(fm, fn) * ET.load(fm, fn);
|
||||
});
|
||||
auto stmt = Block::make({CT.stmt(), DT.stmt(), ET.stmt(), FT.stmt()});
|
||||
|
|
|
|||
|
|
@ -35,7 +35,7 @@ TEST(Reductions, ReduceSum0D_1) {
|
|||
|
||||
std::vector<float> out(M, -1.f);
|
||||
|
||||
Tensor c = Reduce("sum", {{M, "m"}}, Sum(), b, {});
|
||||
Tensor c = Reduce("sum", {M}, Sum(), b, {});
|
||||
LoopNest loop({c});
|
||||
loop.prepareForCodegen();
|
||||
StmtPtr s = loop.root_stmt();
|
||||
|
|
@ -80,7 +80,7 @@ TEST(Reductions, ReduceSum1D) {
|
|||
|
||||
std::vector<float> out(1, -1.f);
|
||||
|
||||
Tensor c = Reduce("sum", {}, Sum(), b, {{10, "m"}});
|
||||
Tensor c = Reduce("sum", {}, Sum(), b, {10});
|
||||
LoopNest loop({c});
|
||||
loop.prepareForCodegen();
|
||||
StmtPtr s = loop.root_stmt();
|
||||
|
|
@ -109,7 +109,7 @@ TEST(Reductions, ReduceSum2D) {
|
|||
|
||||
std::vector<float> out(M, -1.f);
|
||||
|
||||
Tensor c = Reduce("sum", {{M, "m"}}, Sum(), b, {{N, "n"}});
|
||||
Tensor c = Reduce("sum", {M}, Sum(), b, {N});
|
||||
LoopNest loop({c});
|
||||
loop.prepareForCodegen();
|
||||
StmtPtr s = loop.root_stmt();
|
||||
|
|
@ -138,7 +138,7 @@ TEST(Reductions, ReduceSum3D) {
|
|||
|
||||
BufHandle b("b", {2, 3, m}, kFloat);
|
||||
|
||||
Tensor c = Reduce("sum", {{2, "l"}, {3, "n"}}, Sum(), b, {{m, "m"}});
|
||||
Tensor c = Reduce("sum", {2, 3}, Sum(), b, {m});
|
||||
LoopNest loop({c});
|
||||
loop.prepareForCodegen();
|
||||
StmtPtr s = loop.root_stmt();
|
||||
|
|
@ -168,7 +168,7 @@ TEST(Reductions, ReduceSum3D) {
|
|||
ASSERT_EQ(cData[i], expected);
|
||||
}
|
||||
|
||||
Tensor d = Reduce("sum2", {{2, "l"}}, Sum(), b, {{3, "n"}, {m, "m"}});
|
||||
Tensor d = Reduce("sum2", {2}, Sum(), b, {3, m});
|
||||
LoopNest loop2({d});
|
||||
loop2.prepareForCodegen();
|
||||
StmtPtr s2 = loop2.root_stmt();
|
||||
|
|
@ -186,7 +186,7 @@ TEST(Reductions, ReduceSum3D) {
|
|||
|
||||
// This is the same as just reducing the original result across that axis.
|
||||
BufHandle c_buf(c.buf());
|
||||
Tensor e = Reduce("sum3", {{2, "l"}}, Sum(), c_buf, {{3, "m"}});
|
||||
Tensor e = Reduce("sum3", {2}, Sum(), c_buf, {3});
|
||||
LoopNest loop3({e});
|
||||
loop3.prepareForCodegen();
|
||||
StmtPtr s3 = loop3.root_stmt();
|
||||
|
|
@ -210,12 +210,7 @@ TEST(Reductions, ReduceSum10D) {
|
|||
std::vector<float> in(InputSize, 1.f);
|
||||
std::vector<float> out(OutputSize, -1.f);
|
||||
|
||||
Tensor c = Reduce(
|
||||
"sum",
|
||||
{{2, "a"}, {3, "b"}, {2, "c"}, {3, "d"}, {2, "e"}},
|
||||
Sum(),
|
||||
in_,
|
||||
{{3, "f"}, {2, "g"}, {3, "h"}, {2, "i"}, {3, "j"}});
|
||||
Tensor c = Reduce("sum", {2, 3, 2, 3, 2}, Sum(), in_, {3, 2, 3, 2, 3});
|
||||
LoopNest loop({c});
|
||||
loop.prepareForCodegen();
|
||||
StmtPtr s = loop.root_stmt();
|
||||
|
|
@ -250,7 +245,7 @@ TEST(Reductions, ReduceProduct) {
|
|||
Reducer product(
|
||||
ExprHandle(1.f), [](ExprHandle a, ExprHandle b) { return a * b; });
|
||||
|
||||
Tensor c = Reduce("product", {{M, "m"}}, product, b, {{N, "n"}});
|
||||
Tensor c = Reduce("product", {M}, product, b, {N});
|
||||
LoopNest loop({c});
|
||||
loop.prepareForCodegen();
|
||||
StmtPtr s = loop.root_stmt();
|
||||
|
|
@ -281,7 +276,7 @@ TEST(Reductions, ReduceMax) {
|
|||
in[j] = j;
|
||||
}
|
||||
|
||||
Tensor dm1 = Reduce("max", {}, Maximum(kFloat), in_, {{10, "m"}});
|
||||
Tensor dm1 = Reduce("max", {}, Maximum(kFloat), in_, {10});
|
||||
|
||||
LoopNest loop({dm1});
|
||||
loop.prepareForCodegen();
|
||||
|
|
@ -296,7 +291,7 @@ TEST(Reductions, ReduceMax) {
|
|||
BufHandle in2_("b", {2, 5}, kFloat);
|
||||
std::vector<float> out2(2, -1.f);
|
||||
|
||||
Tensor m2d = Reduce("max", {{2, "n"}}, Maximum(kFloat), in2_, {{5, "m"}});
|
||||
Tensor m2d = Reduce("max", {2}, Maximum(kFloat), in2_, {5});
|
||||
|
||||
LoopNest loop2({m2d});
|
||||
loop2.prepareForCodegen();
|
||||
|
|
@ -326,7 +321,7 @@ TEST(Reductions, ReduceMinCustomInitializer) {
|
|||
{},
|
||||
Minimum(ExprHandle(minInit)),
|
||||
[&](ParameterList& v) { return in_.load(v); },
|
||||
{{10, "m"}});
|
||||
{10});
|
||||
|
||||
LoopNest loop({min});
|
||||
loop.prepareForCodegen();
|
||||
|
|
@ -357,12 +352,12 @@ TEST(Reductions, ReduceAnyAll) {
|
|||
|
||||
Tensor any = Reduce(
|
||||
"anyEqual",
|
||||
{{4, "i"}},
|
||||
{4},
|
||||
anyEqSV,
|
||||
[&](const auto& i, const auto& j) {
|
||||
return CompareSelect::make(b.load(i, j), searchValue, kEQ);
|
||||
},
|
||||
{{10, "j"}});
|
||||
{10});
|
||||
|
||||
LoopNest loop({any});
|
||||
loop.prepareForCodegen();
|
||||
|
|
@ -400,12 +395,12 @@ TEST(Reductions, ReduceAnyAll) {
|
|||
|
||||
Tensor allGreaterThan = Reduce(
|
||||
"allGreaterThan",
|
||||
{{4, "i"}},
|
||||
{4},
|
||||
allGTSV,
|
||||
[&](const auto& i, const auto& j) {
|
||||
return CompareSelect::make(b.load(i, j), searchValue, kGT);
|
||||
},
|
||||
{{10, "j"}});
|
||||
{10});
|
||||
|
||||
LoopNest loop2({allGreaterThan});
|
||||
loop2.prepareForCodegen();
|
||||
|
|
@ -448,12 +443,12 @@ TEST(Reductions, ReduceMatmul2D) {
|
|||
|
||||
Tensor mm = Reduce(
|
||||
"mm",
|
||||
{{3, "m"}, {3, "n"}},
|
||||
{3, 3},
|
||||
Sum(),
|
||||
[&](const ExprHandle& m, const ExprHandle& n, const ExprHandle& k) {
|
||||
return tA.load(m, k) * tB.load(k, n);
|
||||
},
|
||||
{{2, "k"}});
|
||||
{2});
|
||||
|
||||
LoopNest loop({mm});
|
||||
loop.prepareForCodegen();
|
||||
|
|
@ -480,10 +475,10 @@ TEST(Reductions, ReduceRfactorLike) {
|
|||
std::vector<float> in_rf_(10, -2.f);
|
||||
std::vector<float> out(1, -1.f);
|
||||
|
||||
Tensor l1 = Reduce("l1", {{10, "i"}}, Sum(), in, {{10, "j"}});
|
||||
Tensor l1 = Reduce("l1", {10}, Sum(), in, {10});
|
||||
BufHandle in_rf(l1.buf());
|
||||
|
||||
Tensor l2 = Reduce("l2", {}, Sum(), in_rf, {{10, "i"}});
|
||||
Tensor l2 = Reduce("l2", {}, Sum(), in_rf, {10});
|
||||
|
||||
LoopNest loop({l1, l2});
|
||||
loop.prepareForCodegen();
|
||||
|
|
@ -503,11 +498,9 @@ TEST(Reductions, ReduceAsProducer) {
|
|||
BufHandle a("a", {2, 3}, kFloat);
|
||||
BufHandle b("b", {2, 3, m}, kFloat);
|
||||
|
||||
Tensor c = Reduce("sum", {{2, "l1"}, {3, "n1"}}, Sum(), b, {{m, "m1"}});
|
||||
Tensor d = Compute(
|
||||
"scale",
|
||||
{{2, "l2"}, {3, "n1"}},
|
||||
[&](const VarHandle& l, const VarHandle& n) {
|
||||
Tensor c = Reduce("sum", {2, 3}, Sum(), b, {m});
|
||||
Tensor d =
|
||||
Compute("scale", {2, 3}, [&](const VarHandle& l, const VarHandle& n) {
|
||||
return c.load(l, n) * a.load(l, n);
|
||||
});
|
||||
LoopNest loop({d}, {c, d});
|
||||
|
|
@ -548,11 +541,11 @@ TEST(Reductions, ReduceAsConsumer) {
|
|||
|
||||
Tensor c = Compute(
|
||||
"scale",
|
||||
{{2, "l2"}, {3, "n1"}, {m, "m1"}},
|
||||
{2, 3, m},
|
||||
[&](const VarHandle& l, const VarHandle& n, const VarHandle& m) {
|
||||
return b.load(l, n, m) * a.load(l, n, m);
|
||||
});
|
||||
Tensor d = Reduce("sum", {{2, "l1"}}, Sum(), c, {{3, "n1"}, {m, "m1"}});
|
||||
Tensor d = Reduce("sum", {2}, Sum(), c, {3, m});
|
||||
LoopNest loop({d}, {c, d});
|
||||
loop.prepareForCodegen();
|
||||
StmtPtr s = loop.root_stmt();
|
||||
|
|
@ -599,7 +592,7 @@ TEST(Reductions, SplitReduceAxis) {
|
|||
}
|
||||
std::vector<float> out(16, -1.f);
|
||||
|
||||
Tensor tensor = Reduce("sum", {{16, "m"}}, Sum(), in, {{8, "n"}});
|
||||
Tensor tensor = Reduce("sum", {16}, Sum(), in, {8});
|
||||
LoopNest l({tensor});
|
||||
std::vector<ForPtr> loops = l.getLoopStmtsFor(tensor);
|
||||
LoopNest::splitWithTail(loops[1], 2);
|
||||
|
|
@ -627,7 +620,7 @@ TEST(Reductions, SplitNonReduceAxis) {
|
|||
}
|
||||
}
|
||||
std::vector<float> out(16, -1.f);
|
||||
Tensor tensor = Reduce("sum", {{16, "m"}}, Sum(), in, {{8, "n"}});
|
||||
Tensor tensor = Reduce("sum", {16}, Sum(), in, {8});
|
||||
LoopNest l({tensor});
|
||||
std::vector<ForPtr> loops = l.getLoopStmtsFor(tensor);
|
||||
LoopNest::splitWithTail(loops[0], 2);
|
||||
|
|
@ -657,14 +650,14 @@ TEST(Reductions, ReorderedReductionInitializer) {
|
|||
BufHandle in("in", {1, 12, 6}, kFloat);
|
||||
std::vector<float> in_(12 * 6, 1.f);
|
||||
|
||||
Tensor tensor_ = Reduce("sum", {{1, "k"}, {12, "n"}}, Sum(), in, {{6, "m"}});
|
||||
Tensor tensor_ = Reduce("sum", {1, 12}, Sum(), in, {6});
|
||||
LoopNest l_({tensor_});
|
||||
|
||||
l_.prepareForCodegen();
|
||||
StmtPtr s_ = Stmt::clone(l_.root_stmt());
|
||||
s_ = IRSimplifier::simplify(s_);
|
||||
|
||||
Tensor tensor = Reduce("sum", {{1, "k"}, {12, "n"}}, Sum(), in, {{6, "m"}});
|
||||
Tensor tensor = Reduce("sum", {1, 12}, Sum(), in, {6});
|
||||
LoopNest l({tensor});
|
||||
|
||||
auto loops = l.getLoopStmtsFor(tensor);
|
||||
|
|
@ -709,7 +702,7 @@ TEST(Reductions, ReduceRfactor) {
|
|||
|
||||
std::vector<float> out(1, -1.f);
|
||||
|
||||
Tensor c = Reduce("sum", {}, Sum(), b, {{m, "m"}, {n, "n"}});
|
||||
Tensor c = Reduce("sum", {}, Sum(), b, {m, n});
|
||||
LoopNest loop({c});
|
||||
std::vector<ForPtr> loops = loop.getLoopStmtsFor(c);
|
||||
auto c_body = loop.getAllWritesToBuf(c.buf())[1];
|
||||
|
|
@ -742,7 +735,7 @@ TEST(Reductions, Reduce3DRfactorInner) {
|
|||
|
||||
std::vector<float> out(1, -1.f);
|
||||
|
||||
Tensor c = Reduce("sum", {}, Sum(), b, {{m, "m"}, {n, "n"}, {k, "k"}});
|
||||
Tensor c = Reduce("sum", {}, Sum(), b, {m, n, k});
|
||||
LoopNest loop({c});
|
||||
std::vector<ForPtr> loops = loop.getLoopStmtsFor(c);
|
||||
auto c_body = loop.getAllWritesToBuf(c.buf())[1];
|
||||
|
|
@ -775,7 +768,7 @@ TEST(Reductions, Reduce3DRfactorOuter) {
|
|||
|
||||
std::vector<float> out(1, -1.f);
|
||||
|
||||
Tensor c = Reduce("sum", {}, Sum(), b, {{m, "m"}, {n, "n"}, {k, "k"}});
|
||||
Tensor c = Reduce("sum", {}, Sum(), b, {m, n, k});
|
||||
LoopNest loop({c});
|
||||
std::vector<ForPtr> loops = loop.getLoopStmtsFor(c);
|
||||
auto c_body = loop.getAllWritesToBuf(c.buf())[1];
|
||||
|
|
@ -799,12 +792,7 @@ TEST(Reductions, ReduceRepeatedInternalRfactor) {
|
|||
std::vector<float> out(1, -1.f);
|
||||
std::vector<float> ref(1, -1.f);
|
||||
|
||||
Tensor c = Reduce(
|
||||
"sum",
|
||||
{},
|
||||
Sum(),
|
||||
in_,
|
||||
{{2, "a"}, {3, "b"}, {4, "c"}, {5, "d"}, {6, "e"}});
|
||||
Tensor c = Reduce("sum", {}, Sum(), in_, {2, 3, 4, 5, 6});
|
||||
LoopNest orig_loop({c});
|
||||
|
||||
// Try rfactoring N outer loops
|
||||
|
|
@ -850,7 +838,7 @@ TEST(Reductions, ReduceSplitTail) {
|
|||
for (const auto i : c10::irange(3)) {
|
||||
std::vector<float> out(M, -1.f);
|
||||
|
||||
Tensor c = Reduce("sum", {{M, "m"}}, Sum(), b, {{N, "n"}, {K, "k"}});
|
||||
Tensor c = Reduce("sum", {M}, Sum(), b, {N, K});
|
||||
LoopNest loop({c});
|
||||
std::vector<ForPtr> loops = loop.getLoopStmtsFor(c);
|
||||
LoopNest::splitWithTail(loops[i], 8);
|
||||
|
|
@ -880,7 +868,7 @@ TEST(Reductions, ReduceSplitNoTail) {
|
|||
for (const auto i : c10::irange(3)) {
|
||||
std::vector<float> out(M, -1.f);
|
||||
|
||||
Tensor c = Reduce("sum", {{M, "m"}}, Sum(), b, {{N, "n"}, {K, "k"}});
|
||||
Tensor c = Reduce("sum", {M}, Sum(), b, {N, K});
|
||||
LoopNest loop({c});
|
||||
std::vector<ForPtr> loops = loop.getLoopStmtsFor(c);
|
||||
LoopNest::splitWithTail(loops[i], 5);
|
||||
|
|
@ -912,7 +900,7 @@ TEST(Reductions, ReduceOverSplitTail) {
|
|||
for (const auto i : c10::irange(3)) {
|
||||
std::vector<float> out(M, -1.f);
|
||||
|
||||
Tensor c = Reduce("sum", {{M, "m"}}, Sum(), b, {{N, "n"}, {K, "k"}});
|
||||
Tensor c = Reduce("sum", {M}, Sum(), b, {N, K});
|
||||
LoopNest loop({c});
|
||||
std::vector<ForPtr> loops = loop.getLoopStmtsFor(c);
|
||||
LoopNest::splitWithTail(loops[i], 16);
|
||||
|
|
@ -943,7 +931,7 @@ TEST(Reductions, ReduceSplitMask) {
|
|||
for (const auto i : c10::irange(3)) {
|
||||
std::vector<float> out(M, -1.f);
|
||||
|
||||
Tensor c = Reduce("sum", {{M, "m"}}, Sum(), b, {{N, "n"}, {K, "k"}});
|
||||
Tensor c = Reduce("sum", {M}, Sum(), b, {N, K});
|
||||
LoopNest loop({c});
|
||||
std::vector<ForPtr> loops = loop.getLoopStmtsFor(c);
|
||||
LoopNest::splitWithMask(loops[i], 8);
|
||||
|
|
@ -973,7 +961,7 @@ TEST(Reductions, ReduceSplitNoMask) {
|
|||
for (const auto i : c10::irange(3)) {
|
||||
std::vector<float> out(M, -1.f);
|
||||
|
||||
Tensor c = Reduce("sum", {{M, "m"}}, Sum(), b, {{N, "n"}, {K, "k"}});
|
||||
Tensor c = Reduce("sum", {M}, Sum(), b, {N, K});
|
||||
LoopNest loop({c});
|
||||
std::vector<ForPtr> loops = loop.getLoopStmtsFor(c);
|
||||
LoopNest::splitWithMask(loops[i], 5);
|
||||
|
|
@ -1004,7 +992,7 @@ TEST(Reductions, ReduceOverSplitMask) {
|
|||
for (const auto i : c10::irange(3)) {
|
||||
std::vector<float> out(M, -1.f);
|
||||
|
||||
Tensor c = Reduce("sum", {{M, "m"}}, Sum(), b, {{N, "n"}, {K, "k"}});
|
||||
Tensor c = Reduce("sum", {M}, Sum(), b, {N, K});
|
||||
LoopNest loop({c});
|
||||
std::vector<ForPtr> loops = loop.getLoopStmtsFor(c);
|
||||
LoopNest::splitWithMask(loops[i], 16);
|
||||
|
|
@ -1038,7 +1026,7 @@ TEST(Reductions, ReduceSplitRfactor) {
|
|||
|
||||
std::vector<float> out(M, -1.f);
|
||||
|
||||
Tensor c = Reduce("sum", {{M, "m"}}, Sum(), b, {{N, "n"}, {K, "k"}});
|
||||
Tensor c = Reduce("sum", {M}, Sum(), b, {N, K});
|
||||
LoopNest loop({c});
|
||||
std::vector<ForPtr> loops = loop.getLoopStmtsFor(c);
|
||||
LoopNest::splitWithTail(loops[2], SPLIT_FACTOR);
|
||||
|
|
@ -1078,7 +1066,7 @@ TEST(Reductions, ReduceOverSplitRfactor) {
|
|||
|
||||
std::vector<float> out(1, -1.f);
|
||||
|
||||
Tensor c = Reduce("sum", {}, Sum(), b, {{N, "n"}, {K, "k"}});
|
||||
Tensor c = Reduce("sum", {}, Sum(), b, {N, K});
|
||||
LoopNest loop({c});
|
||||
std::vector<ForPtr> loops = loop.getLoopStmtsFor(c);
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
|
||||
|
|
@ -1128,10 +1116,9 @@ TEST(Reductions, ReduceInlineReduction) {
|
|||
BufHandle a_buf("a", {M}, kFloat);
|
||||
BufHandle b_buf("b", {M, N, K}, kFloat);
|
||||
|
||||
Tensor x = Reduce("x", {{M, "m1"}}, Sum(), b_buf, {{N, "n1"}, {K, "k1"}});
|
||||
Tensor y = Compute("y", {{M, "m2"}}, [&](const VarHandle& m) {
|
||||
return a_buf.load(m) + x.load(m);
|
||||
});
|
||||
Tensor x = Reduce("x", {M}, Sum(), b_buf, {N, K});
|
||||
Tensor y = Compute(
|
||||
"y", {M}, [&](const VarHandle& m) { return a_buf.load(m) + x.load(m); });
|
||||
|
||||
PaddedBuffer<float> a_v(M);
|
||||
PaddedBuffer<float> b_v(M, N, K);
|
||||
|
|
@ -1162,11 +1149,11 @@ TEST(Reductions, ReduceInlineConsumer) {
|
|||
|
||||
Tensor x = Compute(
|
||||
"x",
|
||||
{{M, "m1"}, {N, "n1"}, {K, "k1"}},
|
||||
{M, N, K},
|
||||
[&](const VarHandle& m, const VarHandle& n, const VarHandle& k) {
|
||||
return a_buf.load(m, n, k) + b_buf.load(m, n, k);
|
||||
});
|
||||
Tensor y = Reduce("y", {{M, "m2"}}, Sum(), x, {{N, "n2"}, {K, "k2"}});
|
||||
Tensor y = Reduce("y", {M}, Sum(), x, {N, K});
|
||||
|
||||
PaddedBuffer<float> a_v(M, N, K);
|
||||
PaddedBuffer<float> b_v(M, N, K);
|
||||
|
|
@ -1215,7 +1202,7 @@ TEST(Reductions, ReduceInlineReducerInternal) {
|
|||
|
||||
Tensor x = Compute(
|
||||
"x",
|
||||
{{M, "m1"}, {N, "n1"}, {K, "k1"}},
|
||||
{M, N, K},
|
||||
[&](const VarHandle& m, const VarHandle& n, const VarHandle& k) {
|
||||
return a_buf.load(m, n, k) + b_buf.load(m, n, k);
|
||||
});
|
||||
|
|
@ -1223,7 +1210,7 @@ TEST(Reductions, ReduceInlineReducerInternal) {
|
|||
Reducer minimum(ExprHandle(0.f), [&](ExprHandle a, ExprHandle b) {
|
||||
return Add::make(ExprHandle(1.f), Min::make(a, b, false));
|
||||
});
|
||||
Tensor y = Reduce("y", {{M, "m2"}}, minimum, x, {{N, "n2"}, {K, "k2"}});
|
||||
Tensor y = Reduce("y", {M}, minimum, x, {N, K});
|
||||
|
||||
PaddedBuffer<float> a_v(M, N, K);
|
||||
PaddedBuffer<float> b_v(M, N, K);
|
||||
|
|
@ -1272,26 +1259,28 @@ TEST(Reductions, ReductionCacheAccessesOperatorAxis) {
|
|||
|
||||
Tensor c = Compute(
|
||||
"scale",
|
||||
{{L, "l2"}, {N, "n1"}, {M, "m1"}},
|
||||
{L, N, M},
|
||||
[&](const VarHandle& l, const VarHandle& n, const VarHandle& m) {
|
||||
return b.load(l, n, m) * a.load(l, n, m);
|
||||
});
|
||||
Tensor d = Reduce("sum", {{L, "l1"}}, Sum(), c, {{N, "n1"}, {M, "m1"}});
|
||||
Tensor d = Reduce("sum", {L}, Sum(), c, {N, M});
|
||||
|
||||
Tensor e = Compute("scale", {{L, "l"}}, [&](const VarHandle& l) {
|
||||
Tensor e = Compute("scale", {L}, [&](const VarHandle& l) {
|
||||
return b.load(0, 0, l) * d.load(l);
|
||||
});
|
||||
|
||||
LoopNest l({e}, {c, d, e});
|
||||
LoopNest l_before(l);
|
||||
l_before.prepareForCodegen();
|
||||
SimpleIREvaluator cg_before(l_before.root_stmt(), {a, b, e});
|
||||
SimpleIREvaluator cg_before(
|
||||
LoopNest::sanitizeNames(l_before.root_stmt()), {a, b, e});
|
||||
|
||||
StmtPtr d_loop = l.getLoopStmtsFor(d)[0];
|
||||
l.cacheAccesses(d.buf(), "d_local", d_loop);
|
||||
l.prepareForCodegen();
|
||||
|
||||
StmtPtr result = IRSimplifier::simplify(l.root_stmt());
|
||||
StmtPtr result =
|
||||
LoopNest::sanitizeNames(IRSimplifier::simplify(l.root_stmt()));
|
||||
SimpleIREvaluator cg_after(result, {a, b, e});
|
||||
|
||||
std::ostringstream oss;
|
||||
|
|
@ -1299,16 +1288,16 @@ TEST(Reductions, ReductionCacheAccessesOperatorAxis) {
|
|||
const std::string& expected_ir =
|
||||
R"IR(
|
||||
#CHECK: Allocate(d_local); // dtype=float, dims=[4]
|
||||
#CHECK: for (int l1
|
||||
#CHECK: d_local[l1] = 0.f
|
||||
#CHECK: for (int n1
|
||||
#CHECK: for (int m1
|
||||
#CHECK: d_local[l1] = (d_local[l1]) + (scale[
|
||||
#CHECK: for (int i_2
|
||||
#CHECK: d_local[i_2] = 0.f
|
||||
#CHECK: for (int
|
||||
#CHECK: for (int
|
||||
#CHECK: d_local[i_2] = (d_local[i_2]) + (scale[
|
||||
#CHECK: }
|
||||
#CHECK: }
|
||||
#CHECK: }
|
||||
#CHECK: for (int i
|
||||
#CHECK: sum[i] = d_local[i]
|
||||
#CHECK: for (int i_3
|
||||
#CHECK: sum[i_3] = d_local[i_3]
|
||||
#CHECK: Free(d_local);
|
||||
#CHECK-NOT: d_local
|
||||
)IR";
|
||||
|
|
@ -1347,13 +1336,13 @@ TEST(Reductions, ReductionCacheAccessesOuterReduceAxis) {
|
|||
|
||||
Tensor c = Compute(
|
||||
"scale",
|
||||
{{L, "l2"}, {N, "n1"}, {M, "m1"}},
|
||||
{L, N, M},
|
||||
[&](const VarHandle& l, const VarHandle& n, const VarHandle& m) {
|
||||
return b.load(l, n, m) * a.load(l, n, m);
|
||||
});
|
||||
Tensor d = Reduce("sum", {{L, "l1"}}, Sum(), c, {{N, "n1"}, {M, "m1"}});
|
||||
Tensor d = Reduce("sum", {L}, Sum(), c, {N, M});
|
||||
|
||||
Tensor e = Compute("scale", {{L, "l"}}, [&](const VarHandle& l) {
|
||||
Tensor e = Compute("scale", {L}, [&](const VarHandle& l) {
|
||||
return b.load(0, 0, l) * d.load(l);
|
||||
});
|
||||
|
||||
|
|
@ -1366,7 +1355,8 @@ TEST(Reductions, ReductionCacheAccessesOuterReduceAxis) {
|
|||
l.cacheAccesses(d.buf(), "d_local", d_loop);
|
||||
l.prepareForCodegen();
|
||||
|
||||
StmtPtr result = IRSimplifier::simplify(l.root_stmt());
|
||||
StmtPtr result =
|
||||
LoopNest::sanitizeNames(IRSimplifier::simplify(l.root_stmt()));
|
||||
SimpleIREvaluator cg_after(result, {a, b, e});
|
||||
|
||||
std::ostringstream oss;
|
||||
|
|
@ -1374,14 +1364,14 @@ TEST(Reductions, ReductionCacheAccessesOuterReduceAxis) {
|
|||
const std::string& expected_ir =
|
||||
R"IR(
|
||||
#CHECK: Allocate(d_local); // dtype=float, dims=[1]
|
||||
#CHECK: sum[l1] = 0
|
||||
#CHECK: d_local[0] = sum[l1]
|
||||
#CHECK: for (int n1
|
||||
#CHECK: for (int m1
|
||||
#CHECK: sum[i_1] = 0
|
||||
#CHECK: d_local[0] = sum[i_1]
|
||||
#CHECK: for (int j_1
|
||||
#CHECK: for (int k_1
|
||||
#CHECK: d_local[0] = (d_local[0]) + (scale[
|
||||
#CHECK: }
|
||||
#CHECK: }
|
||||
#CHECK: sum[l1] = d_local[0]
|
||||
#CHECK: sum[i_1] = d_local[0]
|
||||
#CHECK: Free(d_local);
|
||||
#CHECK-NOT: d_local
|
||||
)IR";
|
||||
|
|
@ -1420,13 +1410,13 @@ TEST(Reductions, ReductionCacheAccessesInnerReduceAxis) {
|
|||
|
||||
Tensor c = Compute(
|
||||
"scale",
|
||||
{{L, "l2"}, {N, "n1"}, {M, "m1"}},
|
||||
{L, N, M},
|
||||
[&](const VarHandle& l, const VarHandle& n, const VarHandle& m) {
|
||||
return b.load(l, n, m) * a.load(l, n, m);
|
||||
});
|
||||
Tensor d = Reduce("sum", {{L, "l1"}}, Sum(), c, {{N, "n1"}, {M, "m1"}});
|
||||
Tensor d = Reduce("sum", {L}, Sum(), c, {N, M});
|
||||
|
||||
Tensor e = Compute("scale", {{L, "l"}}, [&](const VarHandle& l) {
|
||||
Tensor e = Compute("scale", {L}, [&](const VarHandle& l) {
|
||||
return b.load(0, 0, l) * d.load(l);
|
||||
});
|
||||
|
||||
|
|
@ -1439,7 +1429,8 @@ TEST(Reductions, ReductionCacheAccessesInnerReduceAxis) {
|
|||
l.cacheAccesses(d.buf(), "d_local", d_loop);
|
||||
l.prepareForCodegen();
|
||||
|
||||
StmtPtr result = IRSimplifier::simplify(l.root_stmt());
|
||||
StmtPtr result =
|
||||
LoopNest::sanitizeNames(IRSimplifier::simplify(l.root_stmt()));
|
||||
SimpleIREvaluator cg_after(result, {a, b, e});
|
||||
|
||||
std::ostringstream oss;
|
||||
|
|
@ -1447,13 +1438,13 @@ TEST(Reductions, ReductionCacheAccessesInnerReduceAxis) {
|
|||
const std::string& expected_ir =
|
||||
R"IR(
|
||||
#CHECK: Allocate(d_local); // dtype=float, dims=[1]
|
||||
#CHECK: sum[l1] = 0
|
||||
#CHECK: for (int n1
|
||||
#CHECK: sum[i_1] = 0
|
||||
#CHECK: for (int
|
||||
#CHECK: d_local[0] = 0
|
||||
#CHECK: for (int m1
|
||||
#CHECK: for (int
|
||||
#CHECK: d_local[0] = (d_local[0]) + (scale[
|
||||
#CHECK: }
|
||||
#CHECK: sum[l1] = (sum[l1]) + (d_local[0])
|
||||
#CHECK: sum[i_1] = (sum[i_1]) + (d_local[0])
|
||||
#CHECK: }
|
||||
#CHECK: Free(d_local);
|
||||
#CHECK-NOT: d_local
|
||||
|
|
@ -1489,13 +1480,13 @@ TEST(Reductions, ReductionCacheBodyAccess) {
|
|||
|
||||
Tensor c = Compute(
|
||||
"scale",
|
||||
{{24, "l2"}, {32, "n1"}, {12, "m1"}},
|
||||
{24, 32, 12},
|
||||
[&](const VarHandle& l, const VarHandle& n, const VarHandle& m) {
|
||||
return b.load(l, n, m) * a.load(l, n, m);
|
||||
});
|
||||
Tensor d = Reduce("sum", {{24, "l1"}}, Sum(), c, {{32, "n1"}, {12, "m1"}});
|
||||
Tensor d = Reduce("sum", {24}, Sum(), c, {32, 12});
|
||||
|
||||
Tensor e = Compute("scale", {{24, "l"}}, [&](const VarHandle& l) {
|
||||
Tensor e = Compute("scale", {24}, [&](const VarHandle& l) {
|
||||
return b.load(0, 0, l) * d.load(l);
|
||||
});
|
||||
|
||||
|
|
@ -1505,7 +1496,8 @@ TEST(Reductions, ReductionCacheBodyAccess) {
|
|||
l.cacheAccesses(c.buf(), "scale_local", d_loop);
|
||||
|
||||
l.prepareForCodegen();
|
||||
StmtPtr result = IRSimplifier::simplify(l.root_stmt());
|
||||
StmtPtr result =
|
||||
LoopNest::sanitizeNames(IRSimplifier::simplify(l.root_stmt()));
|
||||
SimpleIREvaluator cg(result, {a, b, e});
|
||||
|
||||
std::ostringstream oss;
|
||||
|
|
@ -1513,11 +1505,11 @@ TEST(Reductions, ReductionCacheBodyAccess) {
|
|||
const std::string& expected_ir =
|
||||
R"IR(
|
||||
#CHECK: Allocate(scale_local); // dtype=float, dims=[1, 32, 12]
|
||||
#CHECK: for (int j = 0; j < 32; j++) {
|
||||
#CHECK: for (int k = 0; k < 12; k++) {
|
||||
#CHECK: scale_local[k + 12 * j] = scale[(k + 12 * j) + 384 * l1];
|
||||
#CHECK: sum[l1] = (sum[l1]) + (scale_local[m1_1 + 12 * n1_1]);
|
||||
#CHECK: scale_1[l] = (b[l]) * (sum[l]);
|
||||
#CHECK: for (int j_1 = 0; j_1 < 32; j_1++) {
|
||||
#CHECK: for (int k_1 = 0; k_1 < 12; k_1++) {
|
||||
#CHECK: scale_local[k_1 + 12 * j_1] = scale[(k_1 + 12 * j_1) + 384 * i_1];
|
||||
#CHECK: sum[i_1] = (sum[i_1]) + (scale_local[k_2 + 12 * j_2]);
|
||||
#CHECK: scale_1[i_2] = (b[i_2]) * (sum[i_2]);
|
||||
#CHECK: Free(scale_local);
|
||||
)IR";
|
||||
torch::jit::testing::FileCheck().run(expected_ir, oss.str());
|
||||
|
|
@ -1529,13 +1521,13 @@ TEST(Reductions, ReductionCacheConsumerAccess) {
|
|||
|
||||
Tensor c = Compute(
|
||||
"scale",
|
||||
{{24, "l2"}, {32, "n1"}, {12, "m1"}},
|
||||
{24, 32, 12},
|
||||
[&](const VarHandle& l, const VarHandle& n, const VarHandle& m) {
|
||||
return b.load(l, n, m) * a.load(l, n, m);
|
||||
});
|
||||
Tensor d = Reduce("sum", {{24, "l1"}}, Sum(), c, {{32, "n1"}, {12, "m1"}});
|
||||
Tensor d = Reduce("sum", {24}, Sum(), c, {32, 12});
|
||||
|
||||
Tensor e = Compute("scale", {{24, "l"}}, [&](const VarHandle& l) {
|
||||
Tensor e = Compute("scale", {24}, [&](const VarHandle& l) {
|
||||
return b.load(0, 0, l) * d.load(l);
|
||||
});
|
||||
|
||||
|
|
@ -1547,7 +1539,8 @@ TEST(Reductions, ReductionCacheConsumerAccess) {
|
|||
l.cacheAccesses(d.buf(), "sum_local", e_loop);
|
||||
l.prepareForCodegen();
|
||||
|
||||
StmtPtr result = IRSimplifier::simplify(l.root_stmt());
|
||||
StmtPtr result =
|
||||
LoopNest::sanitizeNames(IRSimplifier::simplify(l.root_stmt()));
|
||||
SimpleIREvaluator cg(result, {a, b, e});
|
||||
|
||||
std::ostringstream oss;
|
||||
|
|
@ -1555,10 +1548,10 @@ TEST(Reductions, ReductionCacheConsumerAccess) {
|
|||
const std::string& expected_ir =
|
||||
R"IR(
|
||||
#CHECK: Alias(sum_local,scale);
|
||||
#CHECK: sum[l1] = (sum[l1]) + (scale[
|
||||
#CHECK: for (int i = 0; i < 4
|
||||
#CHECK: sum_local[i] = sum[i + 4 * l_outer];
|
||||
#CHECK: scale_1[l_inner + 4 * l_outer] = (b[l_inner + 4 * l_outer]) * (sum_local[l_inner]);
|
||||
#CHECK: sum[i_1] = (sum[i_1]) + (scale[
|
||||
#CHECK: for (int j_2 = 0; j_2 < 4
|
||||
#CHECK: sum_local[j_2] = sum[j_2 + 4 * i_2];
|
||||
#CHECK: scale_1[j_3 + 4 * i_2] = (b[j_3 + 4 * i_2]) * (sum_local[j_3]);
|
||||
)IR";
|
||||
torch::jit::testing::FileCheck().run(expected_ir, oss.str());
|
||||
}
|
||||
|
|
@ -1569,13 +1562,13 @@ TEST(Reductions, ReductionSplitCacheConsumerAccess) {
|
|||
|
||||
Tensor c = Compute(
|
||||
"scale",
|
||||
{{24, "l2"}, {32, "n1"}, {12, "m1"}},
|
||||
{24, 32, 12},
|
||||
[&](const VarHandle& l, const VarHandle& n, const VarHandle& m) {
|
||||
return b.load(l, n, m) * a.load(l, n, m);
|
||||
});
|
||||
Tensor d = Reduce("sum", {{24, "l1"}}, Sum(), c, {{32, "n1"}, {12, "m1"}});
|
||||
Tensor d = Reduce("sum", {24}, Sum(), c, {32, 12});
|
||||
|
||||
Tensor e = Compute("scale", {{24, "l"}}, [&](const VarHandle& l) {
|
||||
Tensor e = Compute("scale", {24}, [&](const VarHandle& l) {
|
||||
return b.load(0, 0, l) * d.load(l);
|
||||
});
|
||||
|
||||
|
|
@ -1593,7 +1586,8 @@ TEST(Reductions, ReductionSplitCacheConsumerAccess) {
|
|||
l.cacheAccesses(d.buf(), "sum_local", inner);
|
||||
l.prepareForCodegen();
|
||||
|
||||
StmtPtr result = IRSimplifier::simplify(l.root_stmt());
|
||||
StmtPtr result =
|
||||
LoopNest::sanitizeNames(IRSimplifier::simplify(l.root_stmt()));
|
||||
SimpleIREvaluator cg(result, {a, b, e});
|
||||
|
||||
// reduction changes but cache does not.
|
||||
|
|
@ -1602,10 +1596,12 @@ TEST(Reductions, ReductionSplitCacheConsumerAccess) {
|
|||
const std::string& expected_ir =
|
||||
R"IR(
|
||||
#CHECK: Alias(sum_local,scale);
|
||||
#CHECK: sum[l1_inner + 4 * l1_outer] = (sum[l1_inner + 4 * l1_outer]) + (scale[((m1_1 + 12 * n1_1) + 1536 * l1_outer) + 384 * l1_inner]);
|
||||
#CHECK: for (int i = 0; i < 4
|
||||
#CHECK: sum_local[i] = sum[i + 4 * l_outer];
|
||||
#CHECK: scale_1[l_inner + 4 * l_outer] = (b[l_inner + 4 * l_outer]) * (sum_local[l_inner]);
|
||||
#CHECK: sum[j_1 + 4 * i_1] = (sum[j_1 + 4 * i_1]) + (scale[((l + 12 * k_1) + 1536 * i_1) + 384 * j_1]);
|
||||
#CHECK: for (int i_2 = 0; i_2 < 6
|
||||
#CHECK: for (int j_2 = 0; j_2 < 4
|
||||
#CHECK: sum_local[j_2] = sum[j_2 + 4 * i_2];
|
||||
#CHECK: for (int j_3 = 0; j_3 < 4
|
||||
#CHECK: scale_1[j_3 + 4 * i_2] = (b[j_3 + 4 * i_2]) * (sum_local[j_3]);
|
||||
)IR";
|
||||
torch::jit::testing::FileCheck().run(expected_ir, oss.str());
|
||||
}
|
||||
|
|
@ -1616,13 +1612,13 @@ TEST(Reductions, ReductionReorderCacheConsumerAccess) {
|
|||
|
||||
Tensor c = Compute(
|
||||
"scale",
|
||||
{{24, "l2"}, {32, "n1"}, {12, "m1"}},
|
||||
{24, 32, 12},
|
||||
[&](const VarHandle& l, const VarHandle& n, const VarHandle& m) {
|
||||
return b.load(l, n, m) * a.load(l, n, m);
|
||||
});
|
||||
Tensor d = Reduce("sum", {{24, "l1"}}, Sum(), c, {{32, "n1"}, {12, "m1"}});
|
||||
Tensor d = Reduce("sum", {24}, Sum(), c, {32, 12});
|
||||
|
||||
Tensor e = Compute("scale", {{24, "l"}}, [&](const VarHandle& l) {
|
||||
Tensor e = Compute("scale", {24}, [&](const VarHandle& l) {
|
||||
return b.load(0, 0, l) * d.load(l);
|
||||
});
|
||||
|
||||
|
|
@ -1641,7 +1637,8 @@ TEST(Reductions, ReductionReorderCacheConsumerAccess) {
|
|||
l.cacheAccesses(d.buf(), "sum_local", inner);
|
||||
l.prepareForCodegen();
|
||||
|
||||
StmtPtr result = IRSimplifier::simplify(l.root_stmt());
|
||||
StmtPtr result =
|
||||
LoopNest::sanitizeNames(IRSimplifier::simplify(l.root_stmt()));
|
||||
SimpleIREvaluator cg(result, {a, b, e});
|
||||
|
||||
// neither reduction body nor cache changes.
|
||||
|
|
@ -1649,10 +1646,12 @@ TEST(Reductions, ReductionReorderCacheConsumerAccess) {
|
|||
oss << *cg.stmt();
|
||||
const std::string& expected_ir =
|
||||
R"IR(
|
||||
#CHECK: sum[l1] = (sum[l1]) + (scale[(m1_1 + 12 * n1_1) + 384 * l1]);
|
||||
#CHECK: for (int i = 0; i < 4
|
||||
#CHECK: sum_local[i] = sum[i + 4 * l_outer];
|
||||
#CHECK: scale_1[l_inner + 4 * l_outer] = (b[l_inner + 4 * l_outer]) * (sum_local[l_inner]);
|
||||
#CHECK: sum[j_1] = (sum[j_1]) + (scale[(k_1 + 12 * i_2) + 384 * j_1]);
|
||||
#CHECK: for (int i_3 = 0; i_3 < 6;
|
||||
#CHECK: for (int j_2 = 0; j_2 < 4;
|
||||
#CHECK: sum_local[j_2] = sum[j_2 + 4 * i_3];
|
||||
#CHECK: for (int j_3 = 0; j_3 < 4;
|
||||
#CHECK: scale_1[j_3 + 4 * i_3] = (b[j_3 + 4 * i_3]) * (sum_local[j_3]);
|
||||
)IR";
|
||||
torch::jit::testing::FileCheck().run(expected_ir, oss.str());
|
||||
}
|
||||
|
|
@ -1673,7 +1672,7 @@ TEST(Reductions, ReductionRfactorCacheTempOuter) {
|
|||
|
||||
std::vector<float> out(1, -1.f);
|
||||
|
||||
Tensor c = Reduce("sum", {}, Sum(), b, {{m, "a"}, {n, "b"}, {k, "c"}});
|
||||
Tensor c = Reduce("sum", {}, Sum(), b, {m, n, k});
|
||||
LoopNest loop({c});
|
||||
|
||||
std::vector<ForPtr> loops = loop.getLoopStmtsFor(c);
|
||||
|
|
@ -1693,7 +1692,7 @@ TEST(Reductions, ReductionRfactorCacheTempOuter) {
|
|||
LoopNest::cacheAccesses(rfac_buf, "tmp", all_loops[1][1]);
|
||||
loop.simplify();
|
||||
loop.prepareForCodegen();
|
||||
StmtPtr s = loop.root_stmt();
|
||||
StmtPtr s = LoopNest::sanitizeNames(loop.root_stmt());
|
||||
SimpleIREvaluator cg(s, {b, c, m, n, k});
|
||||
|
||||
std::ostringstream oss;
|
||||
|
|
@ -1702,17 +1701,17 @@ TEST(Reductions, ReductionRfactorCacheTempOuter) {
|
|||
R"IR(
|
||||
#CHECK: Allocate(sum_rfac); // dtype=float, dims=[n]
|
||||
#CHECK: Allocate(tmp); // dtype=float, dims=[n]
|
||||
#CHECK: for (int a = 0; a < m
|
||||
#CHECK: for (int i = 0; i < n
|
||||
#CHECK: tmp[i] = 0
|
||||
#CHECK: for (int i_1 = 0; i_1 < m
|
||||
#CHECK: for (int j = 0; j < n
|
||||
#CHECK: tmp[j] = 0
|
||||
#CHECK: }
|
||||
#CHECK: for (int b = 0; b < n
|
||||
#CHECK: for (int c
|
||||
#CHECK: tmp[b] = (tmp[b]) + (B[
|
||||
#CHECK: for (int j_1 = 0; j_1 < n
|
||||
#CHECK: for (int k
|
||||
#CHECK: tmp[j_1] = (tmp[j_1]) + (B[
|
||||
#CHECK: }
|
||||
#CHECK: }
|
||||
#CHECK: for (int i = 0; i < n
|
||||
#CHECK: sum_rfac[i] = (sum_rfac[i]) + (tmp[i]);
|
||||
#CHECK: for (int j_2 = 0; j_2 < n
|
||||
#CHECK: sum_rfac[j_2] = (sum_rfac[j_2]) + (tmp[j_2]);
|
||||
#CHECK: }
|
||||
#CHECK: Free(tmp);
|
||||
#CHECK-NOT: tmp
|
||||
|
|
@ -1739,7 +1738,7 @@ TEST(Reductions, ReductionRfactorCacheTempInner) {
|
|||
|
||||
std::vector<float> out(1, -1.f);
|
||||
|
||||
Tensor c = Reduce("sum", {}, Sum(), b, {{m, "a"}, {n, "b"}, {k, "c"}});
|
||||
Tensor c = Reduce("sum", {}, Sum(), b, {m, n, k});
|
||||
LoopNest loop({c});
|
||||
std::vector<ForPtr> loops = loop.getLoopStmtsFor(c);
|
||||
auto c_body = loop.getAllWritesToBuf(c.buf())[1];
|
||||
|
|
@ -1759,7 +1758,7 @@ TEST(Reductions, ReductionRfactorCacheTempInner) {
|
|||
LoopNest::cacheAccesses(rfac_buf, "tmp", all_loops[1][2]);
|
||||
loop.prepareForCodegen();
|
||||
loop.simplify();
|
||||
StmtPtr s = loop.root_stmt();
|
||||
StmtPtr s = LoopNest::sanitizeNames(loop.root_stmt());
|
||||
SimpleIREvaluator cg(s, {b, c, m, n, k});
|
||||
|
||||
std::ostringstream oss;
|
||||
|
|
@ -1768,13 +1767,13 @@ TEST(Reductions, ReductionRfactorCacheTempInner) {
|
|||
R"IR(
|
||||
#CHECK: Allocate(sum_rfac); // dtype=float, dims=[n]
|
||||
#CHECK: Allocate(tmp); // dtype=float, dims=[1]
|
||||
#CHECK: for (int a = 0; a < m
|
||||
#CHECK: for (int b = 0; b < n
|
||||
#CHECK: for (int i_1 = 0; i_1 < m
|
||||
#CHECK: for (int j = 0; j < n
|
||||
#CHECK: tmp[0] = 0
|
||||
#CHECK: for (int c
|
||||
#CHECK: for (int k
|
||||
#CHECK: tmp[0] = (tmp[0]) + (B[
|
||||
#CHECK: }
|
||||
#CHECK: sum_rfac[b] = (sum_rfac[b]) + (tmp[0]);
|
||||
#CHECK: sum_rfac[j] = (sum_rfac[j]) + (tmp[0]);
|
||||
#CHECK: Free(tmp);
|
||||
#CHECK-NOT: tmp
|
||||
)IR";
|
||||
|
|
@ -1796,7 +1795,7 @@ TEST(Reductions, ReductionVectorize) {
|
|||
|
||||
BufHandle in("in", {8, 8}, kFloat);
|
||||
|
||||
Tensor tensor = Reduce("sum", {{8, "m"}}, Sum(), in, {{8, "n"}});
|
||||
Tensor tensor = Reduce("sum", {8}, Sum(), in, {8});
|
||||
LoopNest l_before({tensor});
|
||||
LoopNest l(l_before);
|
||||
l_before.prepareForCodegen();
|
||||
|
|
@ -1806,15 +1805,15 @@ TEST(Reductions, ReductionVectorize) {
|
|||
ASSERT_TRUE(LoopNest::vectorize(l.getLoopStmtsFor(tensor)[0]));
|
||||
|
||||
StmtPtr s = l.root_stmt();
|
||||
s = IRSimplifier::simplify(s);
|
||||
s = LoopNest::sanitizeNames(IRSimplifier::simplify(s));
|
||||
|
||||
std::ostringstream oss;
|
||||
oss << *s;
|
||||
const std::string& expected_ir =
|
||||
R"IR(
|
||||
#CHECK: sum[Ramp(0, 1, 8)] = Broadcast(0.f, 8);
|
||||
#CHECK: for (int n = 0; n < 8; n++) {
|
||||
#CHECK: sum[Ramp(0, 1, 8)] = ReduceOp((sum[Ramp(0, 1, 8)]) + (in[Ramp(n, 8, 8)]), reduce_args={n});
|
||||
#CHECK: for (int i = 0; i < 8; i++) {
|
||||
#CHECK: sum[Ramp(0, 1, 8)] = ReduceOp((sum[Ramp(0, 1, 8)]) + (in[Ramp(i, 8, 8)]), reduce_args={i});
|
||||
#CHECK: }
|
||||
)IR";
|
||||
torch::jit::testing::FileCheck().run(expected_ir, oss.str());
|
||||
|
|
@ -1832,7 +1831,7 @@ TEST(Reductions, ReductionVectorize) {
|
|||
TEST(Reductions, ReductionVectorizeInner) {
|
||||
BufHandle in("in", {8, 8}, kFloat);
|
||||
|
||||
Tensor tensor = Reduce("sum", {{8, "m"}}, Sum(), in, {{8, "n"}});
|
||||
Tensor tensor = Reduce("sum", {8}, Sum(), in, {8});
|
||||
LoopNest l({tensor});
|
||||
|
||||
ASSERT_FALSE(LoopNest::vectorize(l.getLoopStmtsFor(tensor)[1]));
|
||||
|
|
@ -1850,7 +1849,7 @@ TEST(Reductions, ReductionVectorizeRfactor) {
|
|||
|
||||
BufHandle in("in", {8, 8}, kFloat);
|
||||
|
||||
Tensor tensor = Reduce("sum", {}, Sum(), in, {{8, "m"}, {8, "n"}});
|
||||
Tensor tensor = Reduce("sum", {}, Sum(), in, {8, 8});
|
||||
|
||||
LoopNest l_before({tensor});
|
||||
LoopNest l(l_before);
|
||||
|
|
@ -1875,21 +1874,21 @@ TEST(Reductions, ReductionVectorizeRfactor) {
|
|||
ASSERT_TRUE(LoopNest::vectorize(rfac_loops[1][0]));
|
||||
l.simplify();
|
||||
|
||||
StmtPtr s = l.root_stmt();
|
||||
StmtPtr s = LoopNest::sanitizeNames(l.root_stmt());
|
||||
|
||||
std::ostringstream oss;
|
||||
oss << *s;
|
||||
const std::string& expected_ir =
|
||||
R"IR(
|
||||
#CHECK: sum = 0.f;
|
||||
#CHECK: for (int n = 0; n < 8; n++) {
|
||||
#CHECK: sum_rfac[n] = 0.f;
|
||||
#CHECK: for (int i = 0; i < 8; i++) {
|
||||
#CHECK: sum_rfac[i] = 0.f;
|
||||
#CHECK: }
|
||||
#CHECK: for (int m = 0; m < 8; m++) {
|
||||
#CHECK: sum_rfac[Ramp(0, 1, 8)] = ReduceOp((sum_rfac[Ramp(0, 1, 8)]) + (in[Ramp(8 * m, 1, 8)]), reduce_args={m});
|
||||
#CHECK: for (int i_1 = 0; i_1 < 8; i_1++) {
|
||||
#CHECK: sum_rfac[Ramp(0, 1, 8)] = ReduceOp((sum_rfac[Ramp(0, 1, 8)]) + (in[Ramp(8 * i_1, 1, 8)]), reduce_args={i_1});
|
||||
#CHECK: }
|
||||
#CHECK: for (int n = 0; n < 8; n++) {
|
||||
#CHECK: sum = ReduceOp((sum) + (sum_rfac[n]), reduce_args={n});
|
||||
#CHECK: for (int i_2 = 0; i_2 < 8; i_2++) {
|
||||
#CHECK: sum = ReduceOp((sum) + (sum_rfac[i_2]), reduce_args={i_2});
|
||||
#CHECK: }
|
||||
)IR";
|
||||
torch::jit::testing::FileCheck().run(expected_ir, oss.str());
|
||||
|
|
@ -1910,22 +1909,22 @@ TEST(Reductions, InitFunction) {
|
|||
BufHandle B("B", {N}, kFloat);
|
||||
Tensor C = Reduce(
|
||||
"C",
|
||||
{{N, "n"}},
|
||||
{N},
|
||||
Sum(),
|
||||
[&](const std::vector<VarHandle>& v) { return B.load(v[0]); },
|
||||
[&](const std::vector<VarHandle>& v) { return A.load(v[1], v[0]); },
|
||||
{{M, "m"}});
|
||||
{M});
|
||||
LoopNest nest({C});
|
||||
nest.prepareForCodegen();
|
||||
StmtPtr s = IRSimplifier::simplify(nest.root_stmt());
|
||||
StmtPtr s = LoopNest::sanitizeNames(IRSimplifier::simplify(nest.root_stmt()));
|
||||
std::ostringstream oss;
|
||||
oss << *s << "\n";
|
||||
const std::string& expected_ir =
|
||||
R"IR(
|
||||
#CHECK: for (int n = 0; n < 16; n++) {
|
||||
#CHECK: C[n] = B[n];
|
||||
#CHECK: for (int m = 0; m < 32; m++) {
|
||||
#CHECK: C[n] = (C[n]) + (A[n + 16 * m]);
|
||||
#CHECK: for (int i = 0; i < 16; i++) {
|
||||
#CHECK: C[i] = B[i];
|
||||
#CHECK: for (int j = 0; j < 32; j++) {
|
||||
#CHECK: C[i] = (C[i]) + (A[i + 16 * j]);
|
||||
#CHECK: }
|
||||
#CHECK: }
|
||||
)IR";
|
||||
|
|
|
|||
|
|
@ -3858,26 +3858,25 @@ TEST(Simplify, SimplifyForCleansUp) {
|
|||
BufHandle a("a", {1, 12, 1}, kFloat);
|
||||
VarHandle x("x", kInt);
|
||||
Tensor b = Compute(
|
||||
// NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
|
||||
"x",
|
||||
{{1, "i"}, {12, "m"}, {1, "n"}},
|
||||
{1, 12, 1},
|
||||
[](const VarHandle& i, const VarHandle& m, const VarHandle& n) {
|
||||
return i + m + n;
|
||||
});
|
||||
LoopNest l({b});
|
||||
l.prepareForCodegen();
|
||||
|
||||
StmtPtr body = l.root_stmt();
|
||||
StmtPtr body = LoopNest::sanitizeNames(l.root_stmt());
|
||||
StmtPtr simplified = IRSimplifier::simplify(body);
|
||||
|
||||
BlockPtr block = to<Block>(simplified);
|
||||
IS_NODE_WITH_NAME(For, block->front(), for_);
|
||||
// for is over "m".
|
||||
IS_VAR_WITH_NAME(for_->var(), "m");
|
||||
IS_VAR_WITH_NAME(for_->var(), "j");
|
||||
// x[m] = m;
|
||||
IS_NODE_WITH_NAME(Store, for_->body()->front(), store);
|
||||
IS_VAR_WITH_NAME(store->flat_index(), "m");
|
||||
IS_VAR_WITH_NAME(store->value(), "m");
|
||||
IS_VAR_WITH_NAME(store->flat_index(), "j");
|
||||
IS_VAR_WITH_NAME(store->value(), "j");
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -186,10 +186,10 @@ int main(int argc, char* argv[]) {
|
|||
// structure is simply a pair of a buffer that was created to represent the
|
||||
// result of the computation (BufPtr) and a statement representing the
|
||||
// computation itself (StmtPtr).
|
||||
Tensor C = Compute(
|
||||
"C",
|
||||
{{64, "i"}, {32, "j"}},
|
||||
[&](const VarHandle& i, const VarHandle& j) { return i * j; });
|
||||
Tensor C =
|
||||
Compute("C", {64, 32}, [&](const VarHandle& i, const VarHandle& j) {
|
||||
return i * j;
|
||||
});
|
||||
std::cout << "Stmt produced by 'Compute' API: " << std::endl
|
||||
<< *C.stmt() << std::endl;
|
||||
// Prints:
|
||||
|
|
@ -209,7 +209,7 @@ int main(int argc, char* argv[]) {
|
|||
{},
|
||||
Sum(),
|
||||
[&](const VarHandle& i, const VarHandle& j) { return C.load(i, j); },
|
||||
{{64, "i"}, {32, "j"}});
|
||||
{64, 32});
|
||||
std::cout << "Stmt produced by 'Reduce' API: " << std::endl
|
||||
<< *D.stmt() << std::endl;
|
||||
}
|
||||
|
|
@ -223,15 +223,13 @@ int main(int argc, char* argv[]) {
|
|||
// Let's look at a couple of transformations that are used in NNC. We will
|
||||
// begin with constructing a Block statement like we did before.
|
||||
|
||||
Tensor C = Compute(
|
||||
"C",
|
||||
{{64, "i"}, {32, "j"}},
|
||||
[&](const VarHandle& i, const VarHandle& j) { return i * (j + 1); });
|
||||
Tensor C =
|
||||
Compute("C", {64, 32}, [&](const VarHandle& i, const VarHandle& j) {
|
||||
return i * (j + 1);
|
||||
});
|
||||
BufHandle c_buf(C.buf());
|
||||
Tensor D = Compute(
|
||||
"D",
|
||||
{{64, "i"}, {32, "j"}},
|
||||
[&](const VarHandle& i, const VarHandle& j) {
|
||||
Tensor D =
|
||||
Compute("D", {64, 32}, [&](const VarHandle& i, const VarHandle& j) {
|
||||
return c_buf.load(i, j) - i;
|
||||
});
|
||||
StmtPtr block = Block::make({C.stmt(), D.stmt()});
|
||||
|
|
@ -353,10 +351,8 @@ int main(int argc, char* argv[]) {
|
|||
// Let's start by constructing a simple computation for us to work with:
|
||||
BufHandle A("A", {64, 32}, kInt);
|
||||
BufHandle B("B", {64, 32}, kInt);
|
||||
Tensor X = Compute(
|
||||
"X",
|
||||
{{64, "i"}, {32, "j"}},
|
||||
[&](const VarHandle& i, const VarHandle& j) {
|
||||
Tensor X =
|
||||
Compute("X", {64, 32}, [&](const VarHandle& i, const VarHandle& j) {
|
||||
return A.load(i, j) + B.load(i, j);
|
||||
});
|
||||
|
||||
|
|
|
|||
|
|
@ -349,19 +349,13 @@ graph(%x : Float(2, 2, strides=[2, 1], requires_grad=0, device=cpu)):
|
|||
graph = torch._C.parse_ir(graph_str)
|
||||
|
||||
def my_custom_lowering(inputs, out_shape, out_type, device):
|
||||
def get_dim_args(dims):
|
||||
dim_args = []
|
||||
for dim in dims:
|
||||
dim_args.append(te.DimArg(dim, "i" + str(len(dim_args))))
|
||||
return dim_args
|
||||
|
||||
def compute(idxs):
|
||||
load = inputs[0].as_buf().load(idxs)
|
||||
return te.ifThenElse(
|
||||
te.ExprHandle.isnan(load), te.ExprHandle.float(0.0), load
|
||||
)
|
||||
|
||||
return te.Compute2("custom_nan_to_num", get_dim_args(out_shape), compute)
|
||||
return te.Compute2("custom_nan_to_num", out_shape, compute)
|
||||
|
||||
kernel = te.TensorExprKernel(graph, {"aten::nan_to_num": my_custom_lowering})
|
||||
res1 = kernel.run((x,))
|
||||
|
|
|
|||
|
|
@ -1,33 +0,0 @@
|
|||
#pragma once
|
||||
#include <torch/csrc/jit/tensorexpr/expr.h>
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
namespace tensorexpr {
|
||||
// A helper structure to store the arguments to specify dimensions. In the
|
||||
// Compute arguments for dim_args, all of the following is supported. For
|
||||
// example:
|
||||
// dim_args: {1, 2, 3, 4}
|
||||
// dim_args: {{1, "x"}, {2, "y"}, {3, "z"}}
|
||||
// dim_args: {1, 2, {3, "x"}}
|
||||
class DimArg {
|
||||
public:
|
||||
// Intentionally leave out explicit to allow implicit conversions.
|
||||
DimArg(const ExprHandle& dim) : dim_(dim) {}
|
||||
DimArg(const ExprHandle& dim, std::string name_hint)
|
||||
: dim_(dim), name_hint_(std::move(name_hint)) {}
|
||||
const ExprHandle& dim() const {
|
||||
return dim_;
|
||||
}
|
||||
const std::string& name_hint() const {
|
||||
return name_hint_;
|
||||
}
|
||||
|
||||
private:
|
||||
ExprHandle dim_;
|
||||
std::string name_hint_;
|
||||
};
|
||||
|
||||
} // namespace tensorexpr
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
|
|
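For reference, a minimal sketch of the call-site migration implied by deleting dim_arg.h (header path and function wrapper assumed, not taken from this commit): dimensions are now passed to Compute/Reduce as plain ExprHandles, and loop-variable names are generated automatically instead of coming from per-dimension name hints.

#include <torch/csrc/jit/tensorexpr/tensor.h>

using namespace torch::jit::tensorexpr;

// Sketch only: the 64x32 example mirrors the tutorial diff above.
Tensor make_example() {
  // Before this change, each dimension carried a name hint via DimArg:
  //   Compute("C", {{64, "i"}, {32, "j"}}, ...);
  // Now dimensions are plain ExprHandles and index names are auto-generated:
  return Compute("C", {64, 32}, [&](const VarHandle& i, const VarHandle& j) {
    return i * j;
  });
}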
@ -402,7 +402,7 @@ void IRPrinter::visit(ReduceOpPtr v) {
|
|||
if (!first) {
|
||||
os() << ", ";
|
||||
}
|
||||
os() << d->name_hint();
|
||||
os() << *d;
|
||||
first = false;
|
||||
}
|
||||
os() << "})";
|
||||
|
|
|
|||
|
|
@ -1033,12 +1033,7 @@ Tensor TensorExprKernel::bindInput(const torch::jit::Value* input) {
|
|||
// if the input isn't contiguous or is an output,
|
||||
// write strided input into contiguous buffer that is
|
||||
// then used in all further compute
|
||||
std::vector<DimArg> inputTensorDims;
|
||||
auto size_handles = sizesFromSymbolicShape(tt->symbolic_sizes());
|
||||
for (size_t i = 0; i < size_handles.size(); i++) {
|
||||
auto size = size_handles[i];
|
||||
inputTensorDims.emplace_back(DimArg(size, "i" + c10::to_string(i)));
|
||||
}
|
||||
auto inputTensorStrides = getInputStrides(input, size_handles);
|
||||
ExprHandle flat_size = 1;
|
||||
for (size_t i = 0; i < size_handles.size(); ++i) {
|
||||
|
|
@ -1057,7 +1052,7 @@ Tensor TensorExprKernel::bindInput(const torch::jit::Value* input) {
|
|||
|
||||
result = Compute(
|
||||
"input" + c10::to_string(bufs_.size() + 1),
|
||||
inputTensorDims,
|
||||
size_handles,
|
||||
[&](const std::vector<VarHandle>& axes) {
|
||||
ExprHandle idx = 0;
|
||||
for (size_t i = 0; i < axes.size(); i++) {
|
||||
|
|
@ -1144,11 +1139,10 @@ Tensor TensorExprKernel::convertSymbolicOutputToCorrectStrides(
|
|||
// for stride in strides_from_largest_to_smallest:
|
||||
// cur_idx = absolute // stride
|
||||
// absolute = absolute % stride
|
||||
auto dims = c10::fmap<DimArg>(sizes);
|
||||
std::vector<ExprPtr> default_strides = make_contiguous_strides(sizes);
|
||||
auto zero = LongImm::make(0);
|
||||
return Compute(
|
||||
"output_1", dims, [&](const std::vector<VarHandle>& axes_input) {
|
||||
"output_1", sizes, [&](const std::vector<VarHandle>& axes_input) {
|
||||
std::vector<ExprHandle> axes(axes_input.begin(), axes_input.end());
|
||||
auto absolute_position = ExprHandle(immLike(axes[0], 0));
|
||||
for (size_t i = 0; i < axes.size(); ++i) {
|
||||
|
|
@ -1191,7 +1185,6 @@ Tensor TensorExprKernel::convertSymbolicOutputToCorrectStrides(
|
|||
tensorOutputStrideDesc_[v->offset()] ==
|
||||
torch::jit::StrideInput::TENSOR_CONT_CHANNELS_LAST);
|
||||
auto sizes = sizesFromSymbolicShape(tt->symbolic_sizes());
|
||||
auto dims = c10::fmap<DimArg>(sizes);
|
||||
auto strides = make_channels_last_strides(sizes);
|
||||
// For a tensor with dimensions N C H W, channels last
|
||||
// format will be N H W C,
|
||||
|
|
@ -1243,7 +1236,7 @@ Tensor TensorExprKernel::convertStaticShapeOutputToCorrectStrides(
|
|||
return Tensor(buf, nullptr);
|
||||
}
|
||||
|
||||
auto dims = c10::fmap<DimArg>(sizesForValue(v));
|
||||
auto dims = sizesForValue(v);
|
||||
auto zero = LongImm::make(0);
|
||||
std::vector<size_t> sorted_stride_indices = reverse_sort_indices(strides);
|
||||
|
||||
|
|
|
|||
|
|
@ -198,7 +198,6 @@ class TORCH_API TensorExprKernel {
|
|||
void genInputDebugNames();
|
||||
void runKernel(Stack& stack);
|
||||
|
||||
std::vector<DimArg> dimsFromSizes(const std::vector<ExprHandle>& sizes);
|
||||
std::vector<ExprHandle> sizesForValue(const torch::jit::Value* v);
|
||||
|
||||
// These functions broadcast shape and also store a `hasBroadcast_` variable.
|
||||
|
|
|
|||
|
|
@ -37,11 +37,13 @@ namespace tensorexpr {
|
|||
LoopNest::LoopNest(const LoopNest& other)
|
||||
: root_stmt_(Stmt::clone(other.root_stmt_)),
|
||||
output_bufs_(other.output_bufs_) {
|
||||
GRAPH_DEBUG("Origin Stmt in LoopNest:\n", std::to_string(root_stmt_));
|
||||
verify(root_stmt_);
|
||||
}
|
||||
|
||||
LoopNest::LoopNest(StmtPtr stmt, std::unordered_set<BufPtr> output_bufs)
|
||||
: root_stmt_(stmt), output_bufs_(std::move(output_bufs)) {
|
||||
GRAPH_DEBUG("Origin Stmt in LoopNest:\n", std::to_string(root_stmt_));
|
||||
verify(root_stmt_);
|
||||
}
|
||||
|
||||
|
|
@ -50,12 +52,14 @@ LoopNest::LoopNest(
|
|||
const std::vector<Tensor>& output_tensors,
|
||||
const std::vector<Tensor>& tensors_to_compute) {
|
||||
initialize(output_tensors, tensors_to_compute);
|
||||
GRAPH_DEBUG("Origin Stmt in LoopNest:\n", std::to_string(root_stmt_));
|
||||
verify(root_stmt_);
|
||||
}
|
||||
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
|
||||
LoopNest::LoopNest(const std::vector<Tensor>& output_tensors) {
|
||||
initialize(output_tensors, output_tensors);
|
||||
GRAPH_DEBUG("Origin Stmt in LoopNest:\n", std::to_string(root_stmt_));
|
||||
verify(root_stmt_);
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -953,9 +953,7 @@ int nnc_lowerings_lazy_registration() {
|
|||
auto const& shape =
|
||||
broadcastShapes(valueShape(inputs[0]), valueShape(inputs[1]));
|
||||
return Compute(
|
||||
"aten_remainder",
|
||||
c10::fmap<DimArg>(shape),
|
||||
[&](const std::vector<VarHandle>& axes) {
|
||||
"aten_remainder", shape, [&](const std::vector<VarHandle>& axes) {
|
||||
std::vector<ExprHandle> indices(axes.begin(), axes.end());
|
||||
std::vector<ExprHandle> exprInputs = {
|
||||
tensorOrConstant(inputs[0], indices),
|
||||
|
|
@ -1454,7 +1452,7 @@ int nnc_lowerings_lazy_registration() {
|
|||
// at::Device device) {
|
||||
// return Compute(
|
||||
// "aten_slice",
|
||||
// c10::fmap<DimArg>(outputShape),
|
||||
// outputShape,
|
||||
// [&](const std::vector<VarHandle>& axes) {
|
||||
// int64_t dim =
|
||||
// at::maybe_wrap_dim(c10::get<int64_t>(inputs[1]),
|
||||
|
|
@ -1475,7 +1473,7 @@ int nnc_lowerings_lazy_registration() {
|
|||
at::Device device) {
|
||||
return Compute(
|
||||
"aten_unsqueeze",
|
||||
c10::fmap<DimArg>(outputShape),
|
||||
outputShape,
|
||||
[&](const std::vector<VarHandle>& axes) {
|
||||
int64_t dim = c10::get<int64_t>(inputs[1]);
|
||||
if (dim < 0) {
|
||||
|
|
@ -1525,7 +1523,7 @@ int nnc_lowerings_lazy_registration() {
|
|||
if (A.ndim() == 0) {
|
||||
auto tensor = Compute(
|
||||
"aten_permute",
|
||||
c10::fmap<DimArg>(outputShape),
|
||||
outputShape,
|
||||
[&](const std::vector<VarHandle>& axes) {
|
||||
std::vector<ExprHandle> empty_indices;
|
||||
return A.load(empty_indices);
|
||||
|
|
@ -1539,7 +1537,7 @@ int nnc_lowerings_lazy_registration() {
|
|||
auto permute_dims = c10::get<IntList>(inputs[1]);
|
||||
auto tensor = Compute(
|
||||
"aten_permute",
|
||||
c10::fmap<DimArg>(outputShape),
|
||||
outputShape,
|
||||
[&](const std::vector<VarHandle>& axes) {
|
||||
std::vector<VarHandle> new_axes;
|
||||
new_axes.resize(axes.size());
|
||||
|
|
|
|||
|
|
@ -49,7 +49,7 @@ Tensor conv2d_depthwise_static(
|
|||
|
||||
Tensor conv = Reduce(
|
||||
"conv2d_depthwise",
|
||||
{{N, "n"}, {K, "k"}, {OH, "oh"}, {OW, "ow"}},
|
||||
{N, K, OH, OW},
|
||||
Sum(),
|
||||
[&](const std::vector<VarHandle>& v) { return init_func(v); },
|
||||
[&](const std::vector<VarHandle>& v) {
|
||||
|
|
@ -70,7 +70,7 @@ Tensor conv2d_depthwise_static(
|
|||
input.load(n, k, oh * stride - pad + r, ow * stride - pad + s));
|
||||
return in * weight.load(k, c, r, s);
|
||||
},
|
||||
{{C / groups, "c"}, {R, "r"}, {S, "s"}});
|
||||
{C / groups, R, S});
|
||||
|
||||
LoopNest nest({conv});
|
||||
|
||||
|
|
@ -120,7 +120,7 @@ Tensor conv2d_depthwise_dynamic(
|
|||
|
||||
return Reduce(
|
||||
"conv2d_depthwise",
|
||||
{{N, "n"}, {K, "k"}, {OH, "oh"}, {OW, "ow"}},
|
||||
{N, K, OH, OW},
|
||||
Sum(),
|
||||
[&](const std::vector<VarHandle>& v) { return init_func(v); },
|
||||
[&](const std::vector<VarHandle>& v) {
|
||||
|
|
@ -141,7 +141,7 @@ Tensor conv2d_depthwise_dynamic(
|
|||
input.load(n, k, oh * stride - pad + r, ow * stride - pad + s));
|
||||
return in * weight.load(k, c, r, s);
|
||||
},
|
||||
{{C / groups, "c"}, {R, "r"}, {S, "s"}});
|
||||
{C / groups, R, S});
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
|
|
|||
|
|
@ -38,12 +38,12 @@ Tensor computeMatmul(
|
|||
if (total_size && total_size->value() < 1000) {
|
||||
return Reduce(
|
||||
"nnc_matmul",
|
||||
{{size_a[0], "M"}, {size_b[1], "N"}},
|
||||
{size_a[0], size_b[1]},
|
||||
Sum(),
|
||||
[&](const ExprHandle& m, const ExprHandle& n, const ExprHandle& k) {
|
||||
return Load::make(a, {m, k}) * Load::make(b, {k, n});
|
||||
},
|
||||
{{size_a[1], "K"}});
|
||||
{size_a[1]});
|
||||
} else {
|
||||
return Tensor(
|
||||
ResultBuf.node(),
|
||||
|
|
|
|||
|
|
@ -324,7 +324,7 @@ Tensor computeChunk(
|
|||
at::Device device) {
|
||||
return Compute(
|
||||
"prim_constantchunk",
|
||||
c10::fmap<DimArg>(outputShape),
|
||||
outputShape,
|
||||
[inputs](const std::vector<VarHandle>& axes) {
|
||||
const auto& b = c10::get<BufHandle>(inputs[0]);
|
||||
int64_t chunkIdx = c10::get<int64_t>(inputs[1]);
|
||||
|
|
@ -359,9 +359,7 @@ Tensor computeTranspose(
|
|||
// Trivial case of 0-dim and 1-dim tensors: transpose is just a copy
|
||||
if (A.ndim() <= 1) {
|
||||
return Compute(
|
||||
"aten_transpose",
|
||||
c10::fmap<DimArg>(outputShape),
|
||||
[&](std::vector<VarHandle> axes) {
|
||||
"aten_transpose", outputShape, [&](std::vector<VarHandle> axes) {
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
axes.size() <= 1,
|
||||
buildErrorMessage("Invalid axes size in transpose"));
|
||||
|
|
@ -372,9 +370,7 @@ Tensor computeTranspose(
|
|||
auto start_dim = at::maybe_wrap_dim(c10::get<int64_t>(inputs[1]), A.ndim());
|
||||
auto to_dim = at::maybe_wrap_dim(c10::get<int64_t>(inputs[2]), A.ndim());
|
||||
return Compute(
|
||||
"aten_transpose",
|
||||
c10::fmap<DimArg>(outputShape),
|
||||
[&](std::vector<VarHandle> axes) {
|
||||
"aten_transpose", outputShape, [&](std::vector<VarHandle> axes) {
|
||||
std::swap(axes[start_dim], axes[to_dim]);
|
||||
return A.load(axes);
|
||||
});
|
||||
|
|
@ -387,9 +383,7 @@ Tensor computeExpand(
|
|||
at::Device device) {
|
||||
auto A = c10::get<BufHandle>(inputs[0]);
|
||||
return Compute(
|
||||
"aten_expand",
|
||||
c10::fmap<DimArg>(outputShape),
|
||||
[&](const std::vector<VarHandle>& axes) {
|
||||
"aten_expand", outputShape, [&](const std::vector<VarHandle>& axes) {
|
||||
std::vector<ExprHandle> indices(axes.begin(), axes.end());
|
||||
return broadcast(A, indices);
|
||||
});
|
||||
|
|
@ -403,17 +397,13 @@ Tensor computeReshape(
|
|||
auto A = c10::get<BufHandle>(inputs[0]);
|
||||
if (A.ndim() == 0) {
|
||||
return Compute(
|
||||
"aten_view",
|
||||
c10::fmap<DimArg>(outputShape),
|
||||
[&](const std::vector<VarHandle>& axes) {
|
||||
"aten_view", outputShape, [&](const std::vector<VarHandle>& axes) {
|
||||
std::vector<ExprHandle> empty_indices;
|
||||
return A.load(empty_indices);
|
||||
});
|
||||
}
|
||||
return Compute(
|
||||
"aten_reshape",
|
||||
c10::fmap<DimArg>(outputShape),
|
||||
[&](const std::vector<VarHandle>& axes) {
|
||||
"aten_reshape", outputShape, [&](const std::vector<VarHandle>& axes) {
|
||||
std::vector<VarHandle> new_axes;
|
||||
assert(outputShape.size() == axes.size());
|
||||
/*
|
||||
|
|
@ -608,9 +598,7 @@ Tensor computeCat(
|
|||
ScalarType highType = catInfo.first;
|
||||
std::vector<BufHandle> nonEmptyInputs = catInfo.second;
|
||||
return Compute(
|
||||
"aten_cat",
|
||||
c10::fmap<DimArg>(outputShape),
|
||||
[&](const std::vector<VarHandle>& axes) {
|
||||
"aten_cat", outputShape, [&](const std::vector<VarHandle>& axes) {
|
||||
if (nonEmptyInputs.size() == 0) {
|
||||
return ExprHandle(0);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -22,9 +22,7 @@ Tensor computeBatchNorm(
|
|||
}
|
||||
|
||||
return Compute(
|
||||
"aten_batch_norm",
|
||||
c10::fmap<DimArg>(outputShape),
|
||||
[&](const std::vector<VarHandle>& axes) {
|
||||
"aten_batch_norm", outputShape, [&](const std::vector<VarHandle>& axes) {
|
||||
TORCH_INTERNAL_ASSERT(axes.size() >= 2);
|
||||
// axes: N, C, H, W
|
||||
std::vector<ExprHandle> indices(axes.begin(), axes.end());
|
||||
|
|
|
|||
|
|
@ -10,16 +10,15 @@ using namespace torch::jit::tensorexpr;
|
|||
Tensor computeSign(
|
||||
const std::vector<ArgValue>& inputValues,
|
||||
const std::vector<ExprHandle>& outputShape) {
|
||||
return Compute(
|
||||
"aten_sign", c10::fmap<DimArg>(outputShape), [&](ParameterList& axes) {
|
||||
std::vector<ExprHandle> indices(axes.begin(), axes.end());
|
||||
std::vector<ExprHandle> inputs = {
|
||||
tensorOrConstant(inputValues[0], indices)};
|
||||
auto inp = inputs[0];
|
||||
auto zero = ExprHandle(immLike(inp, 0.0f));
|
||||
auto res = (zero < inp) - (inp < zero);
|
||||
return promoteToDtype(res, inp.dtype().scalar_type());
|
||||
});
|
||||
return Compute("aten_sign", outputShape, [&](ParameterList& axes) {
|
||||
std::vector<ExprHandle> indices(axes.begin(), axes.end());
|
||||
std::vector<ExprHandle> inputs = {
|
||||
tensorOrConstant(inputValues[0], indices)};
|
||||
auto inp = inputs[0];
|
||||
auto zero = ExprHandle(immLike(inp, 0.0f));
|
||||
auto res = (zero < inp) - (inp < zero);
|
||||
return promoteToDtype(res, inp.dtype().scalar_type());
|
||||
});
|
||||
}
|
||||
|
||||
Tensor computeOneOperand(
|
||||
|
|
@ -31,7 +30,7 @@ Tensor computeOneOperand(
|
|||
const int checkParamTypes) {
|
||||
return Compute(
|
||||
name,
|
||||
c10::fmap<DimArg>(outputShape),
|
||||
outputShape,
|
||||
[inputValues, outputType, innerExpr, checkParamTypes](
|
||||
const std::vector<VarHandle>& axes) {
|
||||
std::vector<ExprHandle> indices(axes.begin(), axes.end());
|
||||
|
|
@ -52,7 +51,7 @@ Tensor computeTwoOperand(
|
|||
innerExpr) {
|
||||
return Compute(
|
||||
name,
|
||||
c10::fmap<DimArg>(outputShape),
|
||||
outputShape,
|
||||
[inputValues, outputType, innerExpr](const std::vector<VarHandle>& axes) {
|
||||
std::vector<ExprHandle> indices(axes.begin(), axes.end());
|
||||
std::vector<ExprHandle> inputs = {
|
||||
|
|
@ -75,7 +74,7 @@ Tensor computeTwoOperandWithAlpha(
|
|||
innerExpr) {
|
||||
return Compute(
|
||||
name,
|
||||
c10::fmap<DimArg>(outputShape),
|
||||
outputShape,
|
||||
[inputValues, outputType, innerExpr](const std::vector<VarHandle>& axes) {
|
||||
std::vector<ExprHandle> indices(axes.begin(), axes.end());
|
||||
std::vector<ExprHandle> inputs = {
|
||||
|
|
@ -100,7 +99,7 @@ Tensor computeConditionWithTwoOperand(
|
|||
innerExpr) {
|
||||
return Compute(
|
||||
name,
|
||||
c10::fmap<DimArg>(outputShape),
|
||||
outputShape,
|
||||
[inputValues, outputType, innerExpr](const std::vector<VarHandle>& axes) {
|
||||
std::vector<ExprHandle> indices(axes.begin(), axes.end());
|
||||
std::vector<ExprHandle> inputs = {
|
||||
|
|
@ -128,7 +127,7 @@ Tensor computeThreeOperand(
|
|||
bool promote_inputs) {
|
||||
return Compute(
|
||||
name,
|
||||
c10::fmap<DimArg>(outputShape),
|
||||
outputShape,
|
||||
[inputValues, outputType, innerExpr, promote_inputs](
|
||||
const std::vector<VarHandle>& axes) {
|
||||
std::vector<ExprHandle> indices(axes.begin(), axes.end());
|
||||
|
|
@ -157,7 +156,7 @@ Tensor computeFourOperand(
|
|||
const ExprHandle&)>& innerExpr) {
|
||||
return Compute(
|
||||
name,
|
||||
c10::fmap<DimArg>(outputShape),
|
||||
outputShape,
|
||||
[inputValues, outputType, innerExpr](const std::vector<VarHandle>& axes) {
|
||||
std::vector<ExprHandle> indices(axes.begin(), axes.end());
|
||||
std::vector<ExprHandle> inputs = {
|
||||
|
|
|
|||
|
|
@ -727,9 +727,7 @@ Tensor computeUpsampleNearest2d(
|
|||
auto input_height = ExprHandle(A.dim(2));
|
||||
auto input_width = ExprHandle(A.dim(3));
|
||||
|
||||
std::vector<ExprHandle> dims;
|
||||
std::vector<VarHandle> args;
|
||||
unpack_dim_args(c10::fmap<DimArg>(outputShape), &dims, &args);
|
||||
std::vector<VarHandle> args = create_index_vars(outputShape);
|
||||
// Handle separately when scale is specified? as in 'scalar_t
|
||||
// compute_scales_value' in UpSample.h
|
||||
auto scale_h =
|
||||
|
|
|
|||
|
|
@ -53,12 +53,12 @@ Tensor computeSum(
|
|||
std::iota(axes.begin(), axes.end(), 0);
|
||||
}
|
||||
// Axes go into reduction dimensions.
|
||||
std::vector<DimArg> reductionDims;
|
||||
std::vector<ExprHandle> reductionDims;
|
||||
reductionDims.reserve(rank);
|
||||
for (size_t axis : axes) {
|
||||
reductionDims.emplace_back(sizes[axis]);
|
||||
}
|
||||
std::vector<DimArg> outputDims;
|
||||
std::vector<ExprHandle> outputDims;
|
||||
// Output dimensions are the complement of axes. When keepdim is set, a
|
||||
// one-sized dimension is inserted for each axis.
|
||||
for (size_t dim = 0; dim < rank; ++dim) {
|
||||
|
|
|
|||
|
|
@ -40,7 +40,6 @@ Tensor computeSoftmax(
// - Final loop computes the log_softmax for every element in v.

TORCH_INTERNAL_ASSERT(inputs.size() == 3);
auto output_dims = c10::fmap<DimArg>(outputShape);

// We do not handle None for dims (input 1) because that is supposed to
// be deprecated.

@ -48,10 +47,10 @@ Tensor computeSoftmax(
int64_t rank = valueShape(inputs[0]).size();
size_t softmax_dim =
normalizeAndCheckIndex(c10::get<int64_t>(inputs[1]), rank);
std::vector<DimArg> non_softmax_dims;
for (size_t i = 0; i < output_dims.size(); ++i) {
std::vector<ExprHandle> non_softmax_dims;
for (size_t i = 0; i < outputShape.size(); ++i) {
if (i != softmax_dim) {
non_softmax_dims.push_back(output_dims[i]);
non_softmax_dims.push_back(outputShape[i]);
}
}

@ -108,9 +107,9 @@ Tensor computeSoftmax(
return tensorOrConstant(
inputs[0], move_softmax_dim_index_to_pos(indices));
},
{output_dims[softmax_dim]});
{outputShape[softmax_dim]});
auto e =
Compute("aten_softmax_exp", output_dims, [&](ParameterList& indices) {
Compute("aten_softmax_exp", outputShape, [&](ParameterList& indices) {
auto inp = tensorOrConstant(
inputs[0], convert_indices_to_expr_handle(indices));
return exp(inp - max.load(remove_softmax_dim_index(indices)));

@ -122,10 +121,10 @@ Tensor computeSoftmax(
[&](ParameterList& indices) {
return e.load(move_softmax_dim_index_to_pos(indices));
},
{output_dims[softmax_dim]});
{outputShape[softmax_dim]});
if (!log_softmax) {
auto result =
Compute("aten_softmax", output_dims, [&](ParameterList& indices) {
Compute("aten_softmax", outputShape, [&](ParameterList& indices) {
return e.load(indices) / sum.load(remove_softmax_dim_index(indices));
});
return Tensor(

@ -139,7 +138,7 @@ Tensor computeSoftmax(
return log(sum.load(indices));
});
auto result =
Compute("aten_log_softmax", output_dims, [&](ParameterList& indices) {
Compute("aten_log_softmax", outputShape, [&](ParameterList& indices) {
auto inp = tensorOrConstant(
inputs[0], convert_indices_to_expr_handle(indices));
auto non_softmax_indices = remove_softmax_dim_index(indices);

@ -1,6 +1,5 @@
#pragma once

#include <torch/csrc/jit/tensorexpr/dim_arg.h>
#include <torch/csrc/jit/tensorexpr/expr.h>
#include <torch/csrc/jit/tensorexpr/ir.h>
#include <torch/csrc/jit/tensorexpr/ir_printer.h>

@ -2,7 +2,6 @@

#include <c10/util/Logging.h>
#include <c10/util/irange.h>
#include <torch/csrc/jit/tensorexpr/dim_arg.h>
#include <torch/csrc/jit/tensorexpr/reduction.h>

namespace torch {
@ -51,11 +50,9 @@ StmtPtr Tensor::constructStmt(

Tensor Compute(
const std::string& name,
const std::vector<DimArg>& dim_args,
const std::vector<ExprHandle>& dims,
const std::function<ExprHandle(const std::vector<VarHandle>&)>& body_func) {
std::vector<ExprHandle> dims;
std::vector<VarHandle> args;
unpack_dim_args(dim_args, &dims, &args);
std::vector<VarHandle> args = create_index_vars(dims);
ExprHandle body = body_func(args);
BufHandle buf = Buf::make(name, dims, body.dtype());
return Tensor(buf, args, body);

@ -63,15 +60,13 @@ Tensor Compute(

Tensor Compute(
const std::string& name,
const std::vector<DimArg>& dim_args,
const std::vector<ExprHandle>& dims,
const std::function<ExprHandle(const VarHandle&)>& body_func) {
if (dim_args.size() != 1) {
if (dims.size() != 1) {
throw malformed_input("mismatch between body and arg size (1)");
}

std::vector<ExprHandle> dims;
std::vector<VarHandle> args;
unpack_dim_args(dim_args, &dims, &args);
std::vector<VarHandle> args = create_index_vars(dims);
ExprHandle body = body_func(args[0]);
BufHandle buf = Buf::make(name, dims, body.dtype());
return Tensor(buf, args, body);

@ -79,15 +74,13 @@ Tensor Compute(

Tensor Compute(
const std::string& name,
const std::vector<DimArg>& dim_args,
const std::vector<ExprHandle>& dims,
const std::function<ExprHandle(const VarHandle&, const VarHandle&)>&
body_func) {
if (dim_args.size() != 2) {
if (dims.size() != 2) {
throw malformed_input("mismatch between body and arg size (2)");
}
std::vector<ExprHandle> dims;
std::vector<VarHandle> args;
unpack_dim_args(dim_args, &dims, &args);
std::vector<VarHandle> args = create_index_vars(dims);
ExprHandle body = body_func(args[0], args[1]);
BufHandle buf = Buf::make(name, dims, body.dtype());
return Tensor(buf, args, body);

@ -95,16 +88,14 @@ Tensor Compute(

Tensor Compute(
const std::string& name,
const std::vector<DimArg>& dim_args,
const std::vector<ExprHandle>& dims,
const std::function<
ExprHandle(const VarHandle&, const VarHandle&, const VarHandle&)>&
body_func) {
if (dim_args.size() != 3) {
if (dims.size() != 3) {
throw malformed_input("mismatch between body and arg size (3)");
}
std::vector<ExprHandle> dims;
std::vector<VarHandle> args;
unpack_dim_args(dim_args, &dims, &args);
std::vector<VarHandle> args = create_index_vars(dims);
ExprHandle body = body_func(args[0], args[1], args[2]);
BufHandle buf = Buf::make(name, dims, body.dtype());
return Tensor(buf, args, body);

@ -112,18 +103,16 @@ Tensor Compute(

Tensor Compute(
const std::string& name,
const std::vector<DimArg>& dim_args,
const std::vector<ExprHandle>& dims,
const std::function<ExprHandle(
const VarHandle&,
const VarHandle&,
const VarHandle&,
const VarHandle&)>& body_func) {
if (dim_args.size() != 4) {
if (dims.size() != 4) {
throw malformed_input("mismatch between body and arg size (4)");
}
std::vector<ExprHandle> dims;
std::vector<VarHandle> args;
unpack_dim_args(dim_args, &dims, &args);
std::vector<VarHandle> args = create_index_vars(dims);
ExprHandle body = body_func(args[0], args[1], args[2], args[3]);
BufHandle buf = Buf::make(name, dims, body.dtype());
return Tensor(buf, args, body);
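A hedged usage sketch of the updated Compute overloads (not part of this diff): dimensions are passed directly as ExprHandles and the index variables are created internally by create_index_vars, so callers no longer wrap sizes in DimArg or supply name hints. M and N below are hypothetical sizes.

int64_t M = 64, N = 128;               // hypothetical sizes
BufHandle A("A", {M, N}, kFloat);
// Two-axis overload: dims are plain ExprHandles, index vars come from inside.
Tensor out = Compute("out", {M, N}, [&](const VarHandle& i, const VarHandle& j) {
  return A.load(i, j) + 1.f;
});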
@ -131,30 +120,30 @@ Tensor Compute(

Tensor Reduce(
const std::string& name,
const std::vector<DimArg>& dim_args,
const std::vector<ExprHandle>& dims,
const Reducer& reducer,
const BufHandle& buffer,
const std::vector<DimArg>& reduce_args) {
const std::vector<ExprHandle>& reduce_dims) {
return Reduce(
name,
dim_args,
dims,
reducer,
[&](ParameterList& p) { return buffer.load(p); },
reduce_args);
reduce_dims);
}

Tensor Reduce(
const std::string& name,
const std::vector<DimArg>& dim_args,
const std::vector<ExprHandle>& dims,
const Reducer& reducer,
Tensor tensor,
const std::vector<DimArg>& reduce_args) {
const std::vector<ExprHandle>& reduce_dims) {
return Reduce(
name,
dim_args,
dims,
reducer,
[&](ParameterList& p) { return tensor.load(p); },
reduce_args);
reduce_dims);
}

} // namespace tensorexpr
@ -4,7 +4,6 @@
#include <functional>
#include <vector>

#include <torch/csrc/jit/tensorexpr/dim_arg.h>
#include <torch/csrc/jit/tensorexpr/expr.h>
#include <torch/csrc/jit/tensorexpr/reduction.h>

@ -73,22 +72,22 @@ class TORCH_API Tensor {

TORCH_API Tensor Compute(
const std::string& func_name,
const std::vector<DimArg>& dim_args,
const std::vector<ExprHandle>& dims,
const std::function<ExprHandle(const VarHandle&)>& body_func);
TORCH_API Tensor Compute(
const std::string& func_name,
const std::vector<DimArg>& dim_args,
const std::vector<ExprHandle>& dims,
const std::function<ExprHandle(const VarHandle&, const VarHandle&)>&
body_func);
TORCH_API Tensor Compute(
const std::string& func_name,
const std::vector<DimArg>& dim_args,
const std::vector<ExprHandle>& dims,
const std::function<
ExprHandle(const VarHandle&, const VarHandle&, const VarHandle&)>&
body_func);
TORCH_API Tensor Compute(
const std::string& func_name,
const std::vector<DimArg>& dim_args,
const std::vector<ExprHandle>& dims,
const std::function<ExprHandle(
const VarHandle&,
const VarHandle&,

@ -96,40 +95,31 @@ TORCH_API Tensor Compute(
const VarHandle&)>& body_func);
TORCH_API Tensor Compute(
const std::string& func_name,
const std::vector<DimArg>& dim_args,
const std::vector<ExprHandle>& dims,
const std::function<ExprHandle(const std::vector<VarHandle>&)>& body_func);

inline void unpack_dim_args(
const std::vector<DimArg>& dim_args,
std::vector<ExprHandle>* dims,
std::vector<VarHandle>* vars) {
dims->clear();
vars->clear();
for (const DimArg& dim_arg : dim_args) {
ExprHandle expr = dim_arg.dim();
dims->push_back(expr);
vars->push_back(VarHandle(alloc<Var>(
dim_arg.name_hint(),
expr.dtype().scalar_type() == ScalarType::Long ? kLong : kInt)));
inline std::vector<VarHandle> create_index_vars(
const std::vector<ExprHandle>& dims) {
std::vector<VarHandle> vars;
vars.reserve(dims.size());
for (const ExprHandle& dim : dims) {
vars.push_back(VarHandle(alloc<Var>(
"i", dim.dtype().scalar_type() == ScalarType::Long ? kLong : kInt)));
}
return vars;
}

// Handle reductions over a Reducer and a body_func which produces values.
template <typename InitFunc, typename BodyFunc>
Tensor Reduce(
const std::string& func_name,
const std::vector<DimArg>& dim_args,
const std::vector<ExprHandle>& dims,
const Reducer& reducer,
const InitFunc& init_func,
const BodyFunc& body_func,
const std::vector<DimArg>& reduce_args) {
std::vector<ExprHandle> dims;
std::vector<VarHandle> vars;
unpack_dim_args(dim_args, &dims, &vars);

std::vector<ExprHandle> reduce_dims;
std::vector<VarHandle> reduce_vars;
unpack_dim_args(reduce_args, &reduce_dims, &reduce_vars);
const std::vector<ExprHandle>& reduce_dims) {
std::vector<VarHandle> vars = create_index_vars(dims);
std::vector<VarHandle> reduce_vars = create_index_vars(reduce_dims);

// If reduce_vars is empty, then it's not a reduction, but rather a simple
// copy
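As a rough illustration (not from this diff), create_index_vars produces one fresh loop variable per dimension, each with the name hint "i" and typed kInt, or kLong when the corresponding dimension is a Long expression. A small sketch with hypothetical sizes:

std::vector<ExprHandle> dims = {64, 128};              // hypothetical sizes
std::vector<VarHandle> axes = create_index_vars(dims);
// axes.size() == 2; both vars carry the hint "i" and dtype kInt here.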
@ -155,45 +145,45 @@ Tensor Reduce(
template <typename BodyFunc>
Tensor Reduce(
const std::string& func_name,
const std::vector<DimArg>& dim_args,
const std::vector<ExprHandle>& dims,
const Reducer& reducer,
const BodyFunc& body_func,
const std::vector<DimArg>& reduce_args) {
const std::vector<ExprHandle>& reduce_dims) {
return Reduce(
func_name,
dim_args,
dims,
reducer,
[&](ParameterList p) { return ExprHandle(reducer.initializer()); },
body_func,
reduce_args);
reduce_dims);
}

// Overload which allows inline lambda functions for the body_func.
template <typename BodyFunc>
Tensor Reduce(
const std::string& func_name,
const std::vector<DimArg>& dim_args,
const std::vector<ExprHandle>& dims,
const Reducer& reducer,
const BodyFunc&& body_func,
const std::vector<DimArg>& reduce_args) {
return Reduce(func_name, dim_args, reducer, body_func, reduce_args);
const std::vector<ExprHandle>& reduce_dims) {
return Reduce(func_name, dims, reducer, body_func, reduce_dims);
}

TORCH_API Tensor Reduce(
const std::string& name,
const std::vector<DimArg>& dim_args,
const std::vector<ExprHandle>& dims,
const Reducer& reducer,
const BufHandle& buffer,
const std::vector<DimArg>& reduce_args);
const std::vector<ExprHandle>& reduce_dims);

// Overload for the common case of all dimensions of a prevously Computed
// Tensor.
TORCH_API Tensor Reduce(
const std::string& func_name,
const std::vector<DimArg>& dim_args,
const std::vector<ExprHandle>& dims,
const Reducer& reducer,
Tensor tensor,
const std::vector<DimArg>& reduce_args);
const std::vector<ExprHandle>& reduce_dims);

template <typename... Ts>
inline ExprHandle Tensor::load(const Ts&... ts) const {
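A small usage sketch (not from this diff) of the last overload above, reducing a previously Computed tensor with the new ExprHandle-based dims; M, N, and the buffer A are hypothetical.

int64_t M = 8, N = 16;                 // hypothetical sizes
BufHandle A("A", {M, N}, kFloat);
Tensor doubled = Compute(
    "doubled", {M, N},
    [&](const VarHandle& i, const VarHandle& j) { return A.load(i, j) * 2.f; });
// Reduce the previously Computed tensor over its second dimension.
Tensor total = Reduce("total", {M}, Sum(), doubled, {N});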
@ -278,17 +278,10 @@ void initTensorExprBindings(PyObject* module) {
self->set_src_value(value.node());
});

py::class_<DimArg>(te, "DimArg")
.def(py::init<const ExprHandle&>())
.def(py::init<const ExprHandle&, const std::string&>());
py::implicitly_convertible<ExprHandle, DimArg>();
py::implicitly_convertible<int32_t, DimArg>();
py::implicitly_convertible<int64_t, DimArg>();

te.def(
"Compute",
[](const std::string& func_name,
const std::vector<DimArg>& dim_args,
const std::vector<ExprHandle>& dim_args,
py::function func) {
if (dim_args.size() == 1) {
return Compute(func_name, dim_args, [&func](const VarHandle& a) {

@ -329,7 +322,7 @@ void initTensorExprBindings(PyObject* module) {
te.def(
"Compute2",
[](const std::string& func_name,
const std::vector<DimArg>& dim_args,
const std::vector<ExprHandle>& dim_args,
py::function func) {
return Compute(
func_name, dim_args, [&func](const std::vector<VarHandle>& dims) {

@ -348,10 +341,10 @@ void initTensorExprBindings(PyObject* module) {
te.def(
"Reduce",
[](const std::string& func_name,
const std::vector<DimArg>& dim_args,
const std::vector<ExprHandle>& dim_args,
const Reducer& reducer,
Tensor buffer,
const std::vector<DimArg>& reduce_args) {
const std::vector<ExprHandle>& reduce_args) {
return Reduce(func_name, dim_args, reducer, buffer, reduce_args);
},
py::return_value_policy::reference);

@ -359,34 +352,34 @@ void initTensorExprBindings(PyObject* module) {
te.def(
"Reduce",
[](const std::string& func_name,
const std::vector<DimArg>& dim_args,
const std::vector<ExprHandle>& dim_args,
const Reducer& reducer,
const BufHandle& buffer,
const std::vector<DimArg>& reduce_args) {
const std::vector<ExprHandle>& reduce_args) {
return Reduce(func_name, dim_args, reducer, buffer, reduce_args);
},
py::return_value_policy::reference);
te.def(
"Reduce",
[](const std::string& func_name,
const std::vector<DimArg>& dim_args,
const std::vector<ExprHandle>& dim_args,
const Reducer& reducer,
const std::function<ExprHandle(const std::vector<VarHandle>&)>&
body_func,
const std::vector<DimArg>& reduce_args) {
const std::vector<ExprHandle>& reduce_args) {
return Reduce(func_name, dim_args, reducer, body_func, reduce_args);
},
py::return_value_policy::reference);
te.def(
"Reduce",
[](const std::string& func_name,
const std::vector<DimArg>& dim_args,
const std::vector<ExprHandle>& dim_args,
const Reducer& reducer,
const std::function<ExprHandle(const std::vector<VarHandle>&)>&
init_func,
const std::function<ExprHandle(const std::vector<VarHandle>&)>&
body_func,
const std::vector<DimArg>& reduce_args) {
const std::vector<ExprHandle>& reduce_args) {
return Reduce(func_name, dim_args, reducer, body_func, reduce_args);
},
py::return_value_policy::reference);