remove tuple logic in constant propagation (#31840)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/31840

The next PR in this stack makes tuples insertable as constants, so we can remove special handling of tuples in constant propagation.
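For a concrete picture of where this is headed, here is a minimal sketch (not part of this diff) that mirrors the first deleted test below. Once the follow-up PR teaches tryInsertConstant to handle tuple IValues, running the pass on this graph should fold the TupleConstruct/TupleIndex chain down to a single prim::Constant[value=0] with no tuple-specific bookkeeping; that end state is an assumption about the next PR in the stack, not something this commit does on its own.

#include <torch/csrc/jit/ir.h>
#include <torch/csrc/jit/irparser.h>
#include <torch/csrc/jit/passes/constant_propagation.h>

using namespace torch::jit;

int main() {
  auto graph = std::make_shared<Graph>();
  // Build x = (0, 1); y = x[0]; return y + y; all statically computable.
  script::parseIR(
      R"IR(
graph():
  %0 : int = prim::Constant[value=0]()
  %1 : int = prim::Constant[value=1]()
  %x : (int, int) = prim::TupleConstruct(%0, %1)
  %y : int = prim::TupleIndex(%x, %0)
  %2 : int = aten::add(%y, %y)
  return (%2))IR",
      &*graph);
  // With tuples insertable as constants, the generic folding path in
  // propagateNode() handles the tuple ops, so the tuple_ops set,
  // pushIValue(), tryInsertTuple(), and the tuples map deleted below
  // are no longer needed.
  ConstantPropagation(graph);
  graph->dump(); // expect a lone prim::Constant[value=0] feeding the return
  return 0;
}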

Test Plan: Imported from OSS

Differential Revision: D19439515

Pulled By: eellison

fbshipit-source-id: c58f153157f1d4eee4c1242decc4f36e41c1aa05
Elias Ellison authored 2020-01-22 12:09:46 -08:00 (committed by Facebook Github Bot)
parent b01d824a78
commit 69492ad6ac
3 changed files with 1 addition and 140 deletions


@@ -1,73 +0,0 @@
-#include <torch/csrc/jit/ir.h>
-#include <torch/csrc/jit/irparser.h>
-#include <torch/csrc/jit/passes/constant_pooling.h>
-#include <torch/csrc/jit/passes/constant_propagation.h>
-#include <torch/csrc/jit/testing/file_check.h>
-#include "test/cpp/jit/test_base.h"
-#include "torch/csrc/jit/custom_operator.h"
-
-#include <sstream>
-#include <string>
-
-namespace torch {
-namespace jit {
-namespace {
-
-inline c10::OperatorOptions _aliasAnalysisFromSchema() {
-  c10::OperatorOptions result;
-  result.setAliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA);
-  return result;
-}
-} // namespace
-
-void testConstantPropagation() {
-  {
-    auto graph = std::make_shared<Graph>();
-    script::parseIR(
-        R"IR(
-graph():
-  %1 : int = prim::Constant[value=1]()
-  %0 : int = prim::Constant[value=0]()
-  %x : (int, int) = prim::TupleConstruct(%0, %1)
-  %y : int = prim::TupleIndex(%x, %0)
-  %5 : int = aten::add(%y, %y)
-  return (%5)
-)IR",
-        &*graph);
-    // optimize through tuple construct and indexing
-    ConstantPropagation(graph);
-    testing::FileCheck()
-        .check("graph")
-        ->check_next("prim::Constant[value=0]")
-        ->check_next("return")
-        ->run(*graph);
-  }
-  {
-    auto graph = std::make_shared<Graph>();
-    script::parseIR(
-        R"IR(
-graph():
-  %10 : None = prim::Constant()
-  %7 : int = prim::Constant[value=0]()
-  %1 : int = prim::Constant[value=1]()
-  %0 : int = prim::Constant[value=3]()
-  %x : (int, int) = prim::TupleConstruct(%0, %1)
-  %y : (int, (int, int)) = prim::TupleConstruct(%1, %x)
-  %6 : (int, int) = prim::TupleIndex(%y, %1)
-  %z : int = prim::TupleIndex(%6, %7)
-  %9 : int = aten::add(%z, %z)
-  %ign = prim::Print(%y, %9)
-  return (%10) )IR",
-        &*graph);
-    ConstantPropagation(graph);
-    // The index should be optimized away, with a computed value of 6,
-    // and the TupleConstructs should still remain
-    testing::FileCheck()
-        .check_count("TupleConstruct", 2)
-        ->check_not("TupleIndex")
-        ->check("value=6")
-        ->run(*graph);
-  }
-}
-} // namespace jit
-} // namespace torch


@@ -42,7 +42,6 @@ namespace jit {
   _(MemoryDAG) \
   _(IRParser) \
   _(ConstantPooling) \
-  _(ConstantPropagation) \
   _(NetDefConverter) \
   _(THNNConv) \
   _(ATenNativeBatchNorm) \


@@ -48,13 +48,6 @@ std::unordered_set<Symbol> skip_list = {
     // where the constant tensor would be large but cheap to create.
 };
 
-std::unordered_set<Symbol> tuple_ops = {
-    prim::TupleSlice,
-    prim::TupleIndex,
-    prim::TupleUnpack,
-    prim::TupleConstruct,
-};
-
 struct ConstantPropagator {
   // Runs constant propagation with an aliasing db and checks if inputs or
   // outputs might be mutated in the graph
@@ -82,20 +75,11 @@
     }
   }
 
-  void pushIValue(Value* v, Stack& stack) {
-    if (tuples.count(v)) {
-      const auto& ival = tuples[v];
-      stack.push_back(ival);
-    } else {
-      stack.push_back(*toIValue(v));
-    }
-  }
-
   std::vector<IValue> runNode(Node* n) {
     auto op = getOperation(n);
     Stack stack;
     for (auto input : n->inputs()) {
-      pushIValue(input, stack);
+      stack.push_back(*toIValue(input));
     }
     op(stack);
     auto var_outputs = fmap(stack, [&](IValue v) -> IValue {
@@ -117,34 +101,6 @@
     return var_outputs;
   }
 
-  // Tuples are not representable as constants, however
-  // we can try to insert each tuple element and then create a TupleConstruct
-  // from the elements
-  Value* tryInsertTuple(const IValue& tuple, Value* tuple_to_replace) {
-    auto type = tuple_to_replace->type();
-    TupleTypePtr tup_type;
-    if (auto opt = type->cast<OptionalType>()) {
-      tup_type = opt->getElementType()->expect<TupleType>();
-    } else {
-      tup_type = type->expect<TupleType>();
-    }
-    auto type_elements = tup_type->elements();
-    const auto& tuple_elements = tuple.toTuple()->elements();
-    std::vector<Value*> inputs;
-    for (size_t i = 0; i < type_elements.size(); ++i) {
-      auto inp = tryInsertConstant(*graph_, tuple_elements[i]);
-      if (inp) {
-        inputs.push_back(*inp);
-      } else {
-        return nullptr;
-      }
-    }
-    auto new_tuple = graph_->insertNode(graph_->createTuple(inputs));
-    tuple_to_replace->replaceAllUsesWith(new_tuple->output());
-    new_tuple->output()->copyMetadata(tuple_to_replace);
-    return new_tuple->output();
-  }
-
   void propagateNode(Node* n) {
     std::vector<IValue> outputs;
     try {
@@ -168,19 +124,6 @@
         (*new_output)->setType(n->outputs()[i]->type());
       }
       n->outputs()[i]->replaceAllUsesWith(*new_output);
-    } else if (outputs[i].isTuple()) {
-      // we save the new Tuple ivalue in case it is used in an op that
-      // forwards tuples later in the graph, such as a Tuple index
-      auto tuple_val = n->outputs()[i];
-      if (auto new_tup = tryInsertTuple(outputs[i], tuple_val)) {
-        GRAPH_UPDATE(
-            "Folding tuple %",
-            n->outputs()[i]->debugName(),
-            " with ",
-            getHeader(new_tup->node()));
-        tuple_val = new_tup;
-      }
-      tuples[tuple_val] = std::move(outputs[i]);
     }
     // If we cannot insert the IValue as a constant, give up replacing the
     // node and let DCE remove it
@@ -322,12 +265,6 @@
         })) {
       return true;
     }
-    if (tuple_ops.count(n->kind())) {
-      return (
-          std::all_of(n->inputs().begin(), n->inputs().end(), [&](Value* v) {
-            return v->node()->kind() == prim::Constant || tuples.count(v);
-          }));
-    }
     return false;
   };
 
@@ -390,8 +327,6 @@
 
   std::shared_ptr<Graph> graph_;
   std::unique_ptr<AliasDb> aliasDb_;
-  // these are tuples which we know the computed IValue for
-  std::unordered_map<Value*, IValue> tuples;
 };
 
 } // anonymous namespace