[JIT] Add variadic stack op (#63578)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63578

Added a new op `prim::VarStack` and a pass that transforms instances of `aten::stack(list, dim)` into `prim::VarStack(list[0], ..., list[n], dim)`. Also provided a JIT interpreter implementation.

Most of the implementation and tests mirror those for `prim::VarConcat`.
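
For illustration, here is a minimal sketch of driving the new pass from C++. This is a hypothetical example, not part of this diff; the helper name `rewriteStackExample` and the example IR are invented:

#include <memory>

#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/ir/irparser.h>
#include <torch/csrc/jit/passes/variadic_ops.h>

// Hypothetical helper, not part of this commit.
void rewriteStackExample() {
  auto graph = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(
      R"IR(
        graph(%0 : Tensor, %1 : Tensor):
          %dim : int = prim::Constant[value=0]()
          %list : Tensor[] = prim::ListConstruct(%0, %1)
          %out : Tensor = aten::stack(%list, %dim)
          return (%out)
      )IR",
      graph.get());

  // Rewrites the aten::stack call to
  //   %out : Tensor = prim::VarStack(%0, %1, %dim)
  // and removes the now-unused prim::ListConstruct. Returns true because
  // the graph was modified.
  bool changed = torch::jit::UseVariadicStack(graph);
  TORCH_INTERNAL_ASSERT(changed);
}

`RemoveListMutationAndUseVariadicStack` additionally removes list mutations (e.g. `aten::append`) where it is safe to do so before applying the same rewrite; see the list-mutation tests below.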

Test Plan: `buck test caffe2/test/cpp/jit:jit -- TestStackOpt`

Reviewed By: navahgar

Differential Revision: D30426232

fbshipit-source-id: 9829a7db6e0a5038c9b7528c43c25b0c221aa2ce
Author: Mike Iovine
Date: 2021-08-24 08:19:38 -07:00
Committed by: Facebook GitHub Bot
Parent: f4aff3a346
Commit: 1385f9fb12

7 changed files with 339 additions and 1 deletion

@@ -84,6 +84,7 @@ namespace c10 {
  _(prim, NumToTensor) \
  _(prim, Uninitialized) \
  _(prim, VarConcat) \
  _(prim, VarStack) \
  _(prim, With) \
  _(prim, Enter) \
  _(prim, Exit) \

@@ -62,6 +62,7 @@ set(JIT_TEST_SRCS
  ${JIT_TEST_ROOT}/test_qualified_name.cpp
  ${JIT_TEST_ROOT}/test_save_load.cpp
  ${JIT_TEST_ROOT}/test_schema_matching.cpp
  ${JIT_TEST_ROOT}/test_stack_opt.cpp
  ${JIT_TEST_ROOT}/test_subgraph_matcher.cpp
  ${JIT_TEST_ROOT}/test_subgraph_rewriter.cpp
  ${JIT_TEST_ROOT}/test_subgraph_utils.cpp

@@ -0,0 +1,308 @@
#include <gtest/gtest.h>

#include <test/cpp/jit/test_utils.h>
#include <torch/csrc/jit/ir/irparser.h>
#include <torch/csrc/jit/passes/variadic_ops.h>
#include <torch/csrc/jit/runtime/interpreter.h>
#include <torch/csrc/jit/testing/file_check.h>

namespace torch {
namespace jit {

TEST(StackOptTest, UseVariadicStack) {
  auto graph = std::make_shared<Graph>();

  const std::string input =
      R"IR(
        graph(%0: Float(56, 56, 56),
              %1: Float(56, 56, 56),
              %2: Float(56, 56, 56),
              %3: Float(56, 56, 56),
              %4: Float(56, 56, 56),
              %5: Float(56, 56, 56)):
          %10 : int = prim::Constant[value=0]()
          %input : Tensor[] = prim::ListConstruct(%0, %1, %2, %3, %4, %5)
          %stack : Float(6, 56, 56, 56) = aten::stack(%input, %10)
          return (%stack)
      )IR";
  parseIR(input, graph.get());

  std::vector<at::Tensor> inputs = {
      at::rand({56, 56, 56}, at::kCPU),
      at::rand({56, 56, 56}, at::kCPU),
      at::rand({56, 56, 56}, at::kCPU),
      at::rand({56, 56, 56}, at::kCPU),
      at::rand({56, 56, 56}, at::kCPU),
      at::rand({56, 56, 56}, at::kCPU)};
  auto orig_outputs = runGraph(graph, inputs);

  ASSERT_TRUE(UseVariadicStack(graph));
  graph->lint();
  auto opt_outputs = runGraph(graph, inputs);
  ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs));

  // After replacing `aten::stack` with `prim::VarStack` we should have the
  // following graph:
  //
  //  graph(%0 : ...,
  //        ...,
  //        %5 : ...):
  //    %zero : int = prim::Constant[value=0]()
  //    %varstack : Tensor = prim::VarStack(%0, %1, %2, %3, %4, %5, %zero)
  //    return (%varstack)
  testing::FileCheck()
      .check_count("= prim::VarStack(", 1, /*exactly*/ true)
      ->check_count("= aten::stack(", 0, /*exactly*/ true)
      ->check_count("= prim::ListConstruct(", 0, /*exactly*/ true)
      ->run(*graph);
}

TEST(StackOptTest, UseVariadicStackReplaceMultiple) {
  auto graph = std::make_shared<Graph>();

  const std::string input =
      R"IR(
        graph(%0: Float(56, 56, 56),
              %1: Float(56, 56, 56),
              %2: Float(56, 56, 56),
              %3: Float(56, 56, 56)):
          %10 : int = prim::Constant[value=0]()
          %input1 : Tensor[] = prim::ListConstruct(%0, %1)
          %stack1 : Float(2, 56, 56, 56) = aten::stack(%input1, %10)
          %input2 : Tensor[] = prim::ListConstruct(%2, %3)
          %stack2 : Float(2, 56, 56, 56) = aten::stack(%input2, %10)
          return (%stack1, %stack2)
      )IR";
  parseIR(input, graph.get());

  std::vector<at::Tensor> inputs = {
      at::rand({56, 56, 56}, at::kCPU),
      at::rand({56, 56, 56}, at::kCPU),
      at::rand({56, 56, 56}, at::kCPU),
      at::rand({56, 56, 56}, at::kCPU)};
  auto orig_outputs = runGraph(graph, inputs);

  ASSERT_TRUE(UseVariadicStack(graph));
  graph->lint();
  auto opt_outputs = runGraph(graph, inputs);
  ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs));

  // After full stack optimization we should have the following graph:
  //
  //  graph(%0 : ...,
  //        %1 : ...,
  //        %2 : ...,
  //        %3 : ...):
  //    %zero : int = prim::Constant[value=0]()
  //    %varstack1 : Tensor = prim::VarStack(%0, %1, %zero)
  //    %varstack2 : Tensor = prim::VarStack(%2, %3, %zero)
  //    return (%varstack1, %varstack2)
  testing::FileCheck()
      .check_count("= prim::VarStack(", 2, /*exactly*/ true)
      ->check_count("= aten::stack(", 0, /*exactly*/ true)
      ->check_count("= prim::ListConstruct(", 0, /*exactly*/ true)
      ->run(*graph);
}

TEST(StackOptTest, UseVariadicStackWithMultipleListUses) {
  auto graph = std::make_shared<Graph>();

  const std::string input =
      R"IR(
        graph(%0: Float(56, 56, 56),
              %1: Float(56, 56, 56)):
          %2 : int = prim::Constant[value=0]()
          %input : Tensor[] = prim::ListConstruct(%0, %1)
          %stack : Float(2, 56, 56, 56) = aten::stack(%input, %2)
          return (%stack, %input)
      )IR";
  parseIR(input, graph.get());

  std::vector<at::Tensor> inputs = {
      at::rand({56, 56, 56}, at::kCPU), at::rand({56, 56, 56}, at::kCPU)};
  auto orig_outputs = runGraph(graph, inputs);

  ASSERT_TRUE(UseVariadicStack(graph));
  graph->lint();
  auto opt_outputs = runGraph(graph, inputs);
  ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs));

  // Since the list has a use other than `aten::stack` (it is a graph output),
  // it cannot be removed. After replacing `aten::stack` with `prim::VarStack`
  // we should have the following graph:
  //
  //  graph(%0 : ...,
  //        %1 : ...):
  //    %zero : int = prim::Constant[value=0]()
  //    %input : Tensor[] = prim::ListConstruct(%0, %1)
  //    %varstack : Tensor = prim::VarStack(%0, %1, %zero)
  //    return (%varstack, %input)
  testing::FileCheck()
      .check_count("= prim::ListConstruct(", 1, /*exactly*/ true)
      ->check_count("= prim::VarStack(", 1, /*exactly*/ true)
      ->check_count("= aten::stack(", 0, /*exactly*/ true)
      ->run(*graph);
}

TEST(StackOptTest, UseVariadicStackWithListMutationAfterStack) {
  auto graph = std::make_shared<Graph>();

  const std::string input =
      R"IR(
        graph(%0: Float(56, 56, 56),
              %1: Float(56, 56, 56),
              %2: Float(56, 56, 56)):
          %10 : int = prim::Constant[value=0]()
          %input : Tensor[] = prim::ListConstruct(%0, %1)
          %stack : Float(2, 56, 56, 56) = aten::stack(%input, %10)
          %11 : Tensor = aten::append(%input, %2)
          return (%stack, %input)
      )IR";
  parseIR(input, graph.get());

  std::vector<at::Tensor> inputs = {
      at::rand({56, 56, 56}, at::kCPU),
      at::rand({56, 56, 56}, at::kCPU),
      at::rand({56, 56, 56}, at::kCPU)};
  auto orig_outputs = runGraph(graph, inputs);

  ASSERT_TRUE(UseVariadicStack(graph));
  graph->lint();
  auto opt_outputs = runGraph(graph, inputs);
  ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs));

  // The input list to `aten::stack` is only mutated after the `aten::stack`
  // op, so `aten::stack` should still be replaced with `prim::VarStack`. The
  // transformed graph should look like the following:
  //
  //  graph(%0 : ...,
  //        %1 : ...,
  //        %2 : ...):
  //    %3 : int = prim::Constant[value=0]()
  //    %4 : Tensor[] = prim::ListConstruct(%0, %1)
  //    %7 : Tensor = prim::VarStack(%0, %1, %3)
  //    %6 : Tensor = aten::append(%4, %2)
  //    return (%7, %4)
  testing::FileCheck()
      .check_count("= prim::ListConstruct(", 1, /*exactly*/ true)
      ->check_count("= prim::VarStack(", 1, /*exactly*/ true)
      ->check_count("= aten::stack(", 0, /*exactly*/ true)
      ->run(*graph);
}

TEST(StackOptTest, UseVariadicStackWithListMutationBeforeStack) {
  auto graph = std::make_shared<Graph>();

  const std::string input =
      R"IR(
        graph(%0: Float(56, 56, 56),
              %1: Float(56, 56, 56),
              %2: Float(56, 56, 56)):
          %10 : int = prim::Constant[value=0]()
          %input : Tensor[] = prim::ListConstruct(%0, %1)
          %11 : Tensor = aten::append(%input, %2)
          %stack : Float(3, 56, 56, 56) = aten::stack(%input, %10)
          return (%stack)
      )IR";
  parseIR(input, graph.get());

  std::vector<at::Tensor> inputs = {
      at::rand({56, 56, 56}, at::kCPU),
      at::rand({56, 56, 56}, at::kCPU),
      at::rand({56, 56, 56}, at::kCPU)};
  auto orig_outputs = runGraph(graph, inputs);

  {
    ASSERT_FALSE(UseVariadicStack(graph));
    graph->lint();
    auto opt_outputs = runGraph(graph, inputs);
    ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs));

    // No transformation should have happened since the `prim::ListConstruct`
    // is mutated before `aten::stack`.
    testing::FileCheck()
        .check_count("= prim::ListConstruct(", 1, /*exactly*/ true)
        ->check_count("= aten::stack(", 1, /*exactly*/ true)
        ->check_count("= prim::VarStack(", 0, /*exactly*/ true)
        ->run(*graph);
  }

  {
    ASSERT_TRUE(RemoveListMutationAndUseVariadicStack(graph));
    graph->lint();
    auto opt_outputs = runGraph(graph, inputs);
    ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs));

    // The mutation of the list must be removed and the `aten::stack` op must
    // be replaced with the `prim::VarStack` op in the graph. The transformed
    // graph should look like the following:
    //
    //  graph(%0 : ...,
    //        %1 : ...,
    //        %2 : ...):
    //    %3 : int = prim::Constant[value=0]()
    //    %7 : Tensor = prim::VarStack(%0, %1, %2, %3)
    //    return (%7)
    testing::FileCheck()
        .check_count("= prim::VarStack(", 1, /*exactly*/ true)
        ->check_count("= prim::ListConstruct(", 0, /*exactly*/ true)
        ->check_count("= aten::stack(", 0, /*exactly*/ true)
        ->run(*graph);
  }
}

TEST(StackOptTest, UseVariadicStackWithMultipleListMutations) {
  auto graph = std::make_shared<Graph>();

  const std::string input =
      R"IR(
        graph(%0: Float(56, 56, 56),
              %1: Float(56, 56, 56),
              %2: Float(56, 56, 56),
              %3: Float(56, 56, 56),
              %4: Float(56, 56, 56)):
          %10 : int = prim::Constant[value=0]()
          %input : Tensor[] = prim::ListConstruct(%0, %1)
          %stack.1 : Float(2, 56, 56, 56) = aten::stack(%input, %10)
          %11 : Tensor = aten::append(%input, %2)
          %stack.2 : Float(3, 56, 56, 56) = aten::stack(%input, %10)
          %12 : Tensor = aten::append(%input, %3)
          %stack.3 : Float(4, 56, 56, 56) = aten::stack(%input, %10)
          %13 : Tensor = aten::append(%input, %4)
          %stack.4 : Float(5, 56, 56, 56) = aten::stack(%input, %10)
          return (%stack.1, %stack.2, %stack.3, %stack.4)
      )IR";
  parseIR(input, graph.get());

  std::vector<at::Tensor> inputs = {
      at::rand({56, 56, 56}, at::kCPU),
      at::rand({56, 56, 56}, at::kCPU),
      at::rand({56, 56, 56}, at::kCPU),
      at::rand({56, 56, 56}, at::kCPU),
      at::rand({56, 56, 56}, at::kCPU)};
  auto orig_outputs = runGraph(graph, inputs);

  ASSERT_TRUE(RemoveListMutationAndUseVariadicStack(graph));
  graph->lint();
  auto opt_outputs = runGraph(graph, inputs);
  ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs));

  // All the mutations of the list must be removed and the `aten::stack` ops
  // must be replaced with `prim::VarStack` ops in the graph. The transformed
  // graph should look like the following:
  //
  //  graph(%0 : ...,
  //        %1 : ...,
  //        %2 : ...,
  //        %3 : ...,
  //        %4 : ...):
  //    %10 : int = prim::Constant[value=0]()
  //    %5 : Tensor = prim::VarStack(%0, %1, %10)
  //    %6 : Tensor = prim::VarStack(%0, %1, %2, %10)
  //    %7 : Tensor = prim::VarStack(%0, %1, %2, %3, %10)
  //    %8 : Tensor = prim::VarStack(%0, %1, %2, %3, %4, %10)
  //    return (%5, %6, %7, %8)
  testing::FileCheck()
      .check_count("= prim::VarStack(", 4, /*exactly*/ true)
      ->check_count("= prim::ListConstruct(", 0, /*exactly*/ true)
      ->check_count("= aten::stack(", 0, /*exactly*/ true)
      ->run(*graph);
}

} // namespace jit
} // namespace torch

@@ -122,5 +122,14 @@ bool RemoveListMutationAndUseVariadicCat(const std::shared_ptr<Graph>& graph) {
  return RemoveListMutationAndUseVariadicOp(graph, aten::cat, prim::VarConcat);
}

bool UseVariadicStack(const std::shared_ptr<Graph>& graph) {
  return UseVariadicOp(graph, aten::stack, prim::VarStack);
}

bool RemoveListMutationAndUseVariadicStack(
    const std::shared_ptr<Graph>& graph) {
  return RemoveListMutationAndUseVariadicOp(graph, aten::stack, prim::VarStack);
}

} // namespace jit
} // namespace torch

@@ -12,5 +12,12 @@ TORCH_API bool UseVariadicCat(const std::shared_ptr<Graph>& graph);
TORCH_API bool RemoveListMutationAndUseVariadicCat(
    const std::shared_ptr<Graph>& graph);

// Replaces the `aten::stack` ops in the given graph with variadic stack ops.
// Returns true if the graph is modified.
TORCH_API bool UseVariadicStack(const std::shared_ptr<Graph>& graph);

TORCH_API bool RemoveListMutationAndUseVariadicStack(
    const std::shared_ptr<Graph>& graph);

} // namespace jit
} // namespace torch

@@ -776,6 +776,18 @@ RegisterOperators reg(
          push(stack, at::cat(inputs, dim));
        },
        aliasAnalysisFromSchema()),
    OperatorGenerator(
        TORCH_SELECTIVE_SCHEMA("prim::VarStack(...) -> Tensor"),
        [](Stack* stack) {
          // For a variadic op, the interpreter pushes all arguments (the
          // tensors, then `dim`) followed by the total argument count, so
          // the count is popped first.
          auto num_inputs = pop(stack).toInt();
          auto dim = pop(stack).toInt();
          // Pop the tensors in reverse so `inputs` ends up in argument order.
          std::vector<at::Tensor> inputs(num_inputs - 1);
          for (int i = 0; i < num_inputs - 1; ++i) {
            inputs[num_inputs - 2 - i] = pop(stack).toTensor();
          }
          push(stack, at::stack(inputs, dim));
        },
        aliasAnalysisFromSchema()),
    OperatorGenerator(
        TORCH_SELECTIVE_SCHEMA(
            "aten::eq.enum(AnyEnumType a, AnyEnumType b) -> bool"),

@@ -214,7 +214,7 @@ std::function<void(ProcessedNode*)> getOutOfPlaceOperation(Node* n) {
 // Returns true if the node represents an op with variadic arguments.
 bool hasVarArgs(Node* n) {
-  if (n->kind() == prim::VarConcat) {
+  if (n->kind() == prim::VarConcat || n->kind() == prim::VarStack) {
     return true;
   }
   return false;