[nnc] Modified the semantics of reorder in using permutation (#61085)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/61085

Test Plan: Imported from OSS

Reviewed By: ZolotukhinM

Differential Revision: D29506679

Pulled By: navahgar

fbshipit-source-id: f674aedff8175b9947404fd2164a0b4f57a71e93
This commit is contained in:
Raghavan Raman 2021-07-15 10:25:09 -07:00 committed by Facebook GitHub Bot
parent 7177509380
commit 2908d3eb45
3 changed files with 12 additions and 11 deletions

View File

@ -6364,14 +6364,14 @@ TEST(LoopNest, reorderNestedLoops3D) {
auto forI = For::make(i, 0, 20, forJ);
auto par = Block::make({forI});
auto reordered = LoopNest::reorder({forI, forJ, forK}, {2, 1, 0});
auto reordered = LoopNest::reorder({forI, forJ, forK}, {2, 0, 1});
ASSERT_EQ(reordered[0], forK);
ASSERT_EQ(reordered[1], forJ);
ASSERT_EQ(reordered[2], forI);
ASSERT_TRUE(LoopNest::areLoopsPerfectlyNested({forK, forJ, forI}));
ASSERT_EQ(reordered[1], forI);
ASSERT_EQ(reordered[2], forJ);
ASSERT_TRUE(LoopNest::areLoopsPerfectlyNested({forK, forI, forJ}));
ASSERT_EQ(forK->get_parent(), par);
ASSERT_EQ(store->get_parent(), forI->body());
ASSERT_EQ(store->get_parent(), forJ->body());
}
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
@ -6403,13 +6403,13 @@ TEST(LoopNest, reorderNestedLoops4D) {
auto forI = For::make(i, 0, 20, forJ);
auto par = Block::make({forI});
auto reordered = LoopNest::reorder({forI, forJ, forK, forL}, {2, 3, 0, 1});
auto reordered = LoopNest::reorder({forI, forJ, forK, forL}, {2, 0, 3, 1});
ASSERT_EQ(reordered[0], forK);
ASSERT_EQ(reordered[1], forL);
ASSERT_EQ(reordered[2], forI);
ASSERT_EQ(reordered[1], forI);
ASSERT_EQ(reordered[2], forL);
ASSERT_EQ(reordered[3], forJ);
ASSERT_TRUE(LoopNest::areLoopsPerfectlyNested({forK, forL, forI, forJ}));
ASSERT_TRUE(LoopNest::areLoopsPerfectlyNested({forK, forI, forL, forJ}));
ASSERT_EQ(forK->get_parent(), par);
ASSERT_EQ(store->get_parent(), forJ->body());
}

View File

@ -1956,7 +1956,7 @@ std::vector<For*> LoopNest::reorder(
// Reorder the loops according to the permutation.
std::vector<For*> result(loops.size());
for (size_t i = 0; i < loops.size(); ++i) {
result[permutation[i]] = loops[i];
result[i] = loops[permutation[i]];
}
// Remove the bodies from all the loops.

View File

@ -295,7 +295,8 @@ class TORCH_API LoopNest {
static void reorderAxis(For* a, For* b);
// Reorder the given list of loops according to the permutation specified.
// Here permutation[i] represents the location of the loop i in the result.
// Here `permutation[i]` represents the position of the loop in the input
// which will end up at position `i` after the reorder.
//
// For example, consider the following code:
// for p