mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
remove tuple logic in constant propagation (#31840)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/31840 The next PR in this stack makes tuples insertable as constants, so we can remove special handling of tuples in constant propagation. Test Plan: Imported from OSS Differential Revision: D19439515 Pulled By: eellison fbshipit-source-id: c58f153157f1d4eee4c1242decc4f36e41c1aa05
This commit is contained in:
parent
b01d824a78
commit
69492ad6ac
|
|
@ -1,73 +0,0 @@
|
|||
#include <torch/csrc/jit/ir.h>
|
||||
#include <torch/csrc/jit/irparser.h>
|
||||
#include <torch/csrc/jit/passes/constant_pooling.h>
|
||||
#include <torch/csrc/jit/passes/constant_propagation.h>
|
||||
#include <torch/csrc/jit/testing/file_check.h>
|
||||
#include "test/cpp/jit/test_base.h"
|
||||
#include "torch/csrc/jit/custom_operator.h"
|
||||
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
|
||||
namespace {
|
||||
inline c10::OperatorOptions _aliasAnalysisFromSchema() {
|
||||
c10::OperatorOptions result;
|
||||
result.setAliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA);
|
||||
return result;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
void testConstantPropagation() {
|
||||
{
|
||||
auto graph = std::make_shared<Graph>();
|
||||
script::parseIR(
|
||||
R"IR(
|
||||
graph():
|
||||
%1 : int = prim::Constant[value=1]()
|
||||
%0 : int = prim::Constant[value=0]()
|
||||
%x : (int, int) = prim::TupleConstruct(%0, %1)
|
||||
%y : int = prim::TupleIndex(%x, %0)
|
||||
%5 : int = aten::add(%y, %y)
|
||||
return (%5)
|
||||
)IR",
|
||||
&*graph);
|
||||
// optimize through tuple construct and indexing
|
||||
ConstantPropagation(graph);
|
||||
testing::FileCheck()
|
||||
.check("graph")
|
||||
->check_next("prim::Constant[value=0]")
|
||||
->check_next("return")
|
||||
->run(*graph);
|
||||
}
|
||||
{
|
||||
auto graph = std::make_shared<Graph>();
|
||||
script::parseIR(
|
||||
R"IR(
|
||||
graph():
|
||||
%10 : None = prim::Constant()
|
||||
%7 : int = prim::Constant[value=0]()
|
||||
%1 : int = prim::Constant[value=1]()
|
||||
%0 : int = prim::Constant[value=3]()
|
||||
%x : (int, int) = prim::TupleConstruct(%0, %1)
|
||||
%y : (int, (int, int)) = prim::TupleConstruct(%1, %x)
|
||||
%6 : (int, int) = prim::TupleIndex(%y, %1)
|
||||
%z : int = prim::TupleIndex(%6, %7)
|
||||
%9 : int = aten::add(%z, %z)
|
||||
%ign = prim::Print(%y, %9)
|
||||
return (%10) )IR",
|
||||
&*graph);
|
||||
ConstantPropagation(graph);
|
||||
// The index should be optimized away, with a computed value of 6,
|
||||
// and the TupleConstructs should still remain
|
||||
testing::FileCheck()
|
||||
.check_count("TupleConstruct", 2)
|
||||
->check_not("TupleIndex")
|
||||
->check("value=6")
|
||||
->run(*graph);
|
||||
}
|
||||
}
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
|
|
@ -42,7 +42,6 @@ namespace jit {
|
|||
_(MemoryDAG) \
|
||||
_(IRParser) \
|
||||
_(ConstantPooling) \
|
||||
_(ConstantPropagation) \
|
||||
_(NetDefConverter) \
|
||||
_(THNNConv) \
|
||||
_(ATenNativeBatchNorm) \
|
||||
|
|
|
|||
|
|
@ -48,13 +48,6 @@ std::unordered_set<Symbol> skip_list = {
|
|||
// where the constant tensor would be large but cheap to create.
|
||||
};
|
||||
|
||||
std::unordered_set<Symbol> tuple_ops = {
|
||||
prim::TupleSlice,
|
||||
prim::TupleIndex,
|
||||
prim::TupleUnpack,
|
||||
prim::TupleConstruct,
|
||||
};
|
||||
|
||||
struct ConstantPropagator {
|
||||
// Runs constant propagation with an aliasing db and checks if inputs or
|
||||
// outputs might be mutated in the graph
|
||||
|
|
@ -82,20 +75,11 @@ struct ConstantPropagator {
|
|||
}
|
||||
}
|
||||
|
||||
void pushIValue(Value* v, Stack& stack) {
|
||||
if (tuples.count(v)) {
|
||||
const auto& ival = tuples[v];
|
||||
stack.push_back(ival);
|
||||
} else {
|
||||
stack.push_back(*toIValue(v));
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<IValue> runNode(Node* n) {
|
||||
auto op = getOperation(n);
|
||||
Stack stack;
|
||||
for (auto input : n->inputs()) {
|
||||
pushIValue(input, stack);
|
||||
stack.push_back(*toIValue(input));
|
||||
}
|
||||
op(stack);
|
||||
auto var_outputs = fmap(stack, [&](IValue v) -> IValue {
|
||||
|
|
@ -117,34 +101,6 @@ struct ConstantPropagator {
|
|||
return var_outputs;
|
||||
}
|
||||
|
||||
// Tuples are not representable as constants, however
|
||||
// we can try to insert each tuple element and then create a TupleConstruct
|
||||
// from the elements
|
||||
Value* tryInsertTuple(const IValue& tuple, Value* tuple_to_replace) {
|
||||
auto type = tuple_to_replace->type();
|
||||
TupleTypePtr tup_type;
|
||||
if (auto opt = type->cast<OptionalType>()) {
|
||||
tup_type = opt->getElementType()->expect<TupleType>();
|
||||
} else {
|
||||
tup_type = type->expect<TupleType>();
|
||||
}
|
||||
auto type_elements = tup_type->elements();
|
||||
const auto& tuple_elements = tuple.toTuple()->elements();
|
||||
std::vector<Value*> inputs;
|
||||
for (size_t i = 0; i < type_elements.size(); ++i) {
|
||||
auto inp = tryInsertConstant(*graph_, tuple_elements[i]);
|
||||
if (inp) {
|
||||
inputs.push_back(*inp);
|
||||
} else {
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
auto new_tuple = graph_->insertNode(graph_->createTuple(inputs));
|
||||
tuple_to_replace->replaceAllUsesWith(new_tuple->output());
|
||||
new_tuple->output()->copyMetadata(tuple_to_replace);
|
||||
return new_tuple->output();
|
||||
}
|
||||
|
||||
void propagateNode(Node* n) {
|
||||
std::vector<IValue> outputs;
|
||||
try {
|
||||
|
|
@ -168,19 +124,6 @@ struct ConstantPropagator {
|
|||
(*new_output)->setType(n->outputs()[i]->type());
|
||||
}
|
||||
n->outputs()[i]->replaceAllUsesWith(*new_output);
|
||||
} else if (outputs[i].isTuple()) {
|
||||
// we save the new Tuple ivalue in case it is used in an op that
|
||||
// forwards tuples later in the graph, such as a Tuple index
|
||||
auto tuple_val = n->outputs()[i];
|
||||
if (auto new_tup = tryInsertTuple(outputs[i], tuple_val)) {
|
||||
GRAPH_UPDATE(
|
||||
"Folding tuple %",
|
||||
n->outputs()[i]->debugName(),
|
||||
" with ",
|
||||
getHeader(new_tup->node()));
|
||||
tuple_val = new_tup;
|
||||
}
|
||||
tuples[tuple_val] = std::move(outputs[i]);
|
||||
}
|
||||
// If we cannot insert the IValue as a constant, give up replacing the
|
||||
// node and let DCE remove it
|
||||
|
|
@ -322,12 +265,6 @@ struct ConstantPropagator {
|
|||
})) {
|
||||
return true;
|
||||
}
|
||||
if (tuple_ops.count(n->kind())) {
|
||||
return (
|
||||
std::all_of(n->inputs().begin(), n->inputs().end(), [&](Value* v) {
|
||||
return v->node()->kind() == prim::Constant || tuples.count(v);
|
||||
}));
|
||||
}
|
||||
return false;
|
||||
};
|
||||
|
||||
|
|
@ -390,8 +327,6 @@ struct ConstantPropagator {
|
|||
|
||||
std::shared_ptr<Graph> graph_;
|
||||
std::unique_ptr<AliasDb> aliasDb_;
|
||||
// these are tuples which we know the computed IValue for
|
||||
std::unordered_map<Value*, IValue> tuples;
|
||||
};
|
||||
} // anonymous namespace
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user