pytorch/torch/csrc/jit/passes/lower_tuples.cpp
FFFrog af0bc75460 Remove deprecated alias macro(1/3) (#137556)
**Detailed Descriptions:**
- Remove AT_ERROR Macro

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137556
Approved by: https://github.com/ezyang
2024-10-21 17:32:32 +00:00

342 lines
11 KiB
C++

#include <torch/csrc/jit/passes/lower_tuples.h>
#include <ATen/core/functional.h>
#include <c10/util/Exception.h>
#include <c10/util/irange.h>
#include <torch/csrc/jit/ir/constants.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <utility>
namespace torch::jit {
namespace {
// Node kinds through which this pass knows how to flatten tuples.
// flattenInputs and flattenOutputs consult this set before rewriting a node;
// for any kind that is not listed they only emit a warning and leave the
// tuple-typed value in place. The set is read-only lookup data, hence const.
const std::unordered_set<Symbol> supported_ops = {
    prim::If,
    prim::Loop,
    prim::Uninitialized,
    prim::TupleUnpack,
    prim::TupleConstruct,
    prim::TupleIndex,
    prim::TupleSlice,
    prim::Param,
    prim::Return,
    prim::PythonOp,
    aten::format,
    aten::__getitem__};
// Expands the tuple-typed input `index` of `n` into one input per tuple
// element, taking the elements from the inputs of the node that produced the
// tuple, and mirrors the change on the parameters of n's first block: a fresh
// block parameter is added per element and a TupleConstruct prepended to the
// block rebuilds the tuple for all existing uses of the old parameter.
//
// The block parameter paired with node input `index` lives at position
// `index - 1` of the block's inputs.
//
// flattenInputs calls this only for prim::Loop nodes, after checking that the
// input is produced by a prim::TupleConstruct.
//
// NOTE(review): the freshly inserted block parameters are not given an
// explicit type here (flattenOutputs calls setType on the values it creates)
// -- confirm the default type is acceptable for non-tensor elements.
static void flattenTupleInLoopParams(Node* n, size_t index) {
  Value* tuple_input = n->inputs().at(index);
  TupleTypePtr tt = tuple_input->type()->cast<TupleType>();
  TORCH_INTERNAL_ASSERT(tt);

  Block* body = n->blocks().at(0);
  // Prepended so that the rebuilt tuple precedes every node already in the
  // block.
  Node* rebuilt_tuple =
      body->prependNode(body->owningGraph()->create(prim::TupleConstruct));

  Node* producer = tuple_input->node();
  const size_t num_elements = tt->elements().size();
  for (const auto j : c10::irange(num_elements)) {
    // New entries go right behind the tuple on both sides, so the tuple
    // itself keeps its position until it is erased below.
    rebuilt_tuple->addInput(body->insertInput(index + j));
    n->insertInput(index + j + 1, producer->inputs().at(j));
  }

  Value* old_param = body->inputs().at(index - 1);
  rebuilt_tuple->output()->setType(old_param->type());
  rebuilt_tuple->copyMetadata(n);
  old_param->replaceAllUsesWith(rebuilt_tuple->output());

  // Finally drop the tuple-typed parameter and the tuple-typed node input.
  body->eraseInput(index - 1);
  n->removeInput(index);
}
// Expands the tuple-typed input `index` of the return node `n` into one block
// output per tuple element. When the block belongs to a node (i.e. it is not
// a top-level block), the matching output of that owning node is expanded as
// well and a TupleConstruct is inserted right after the owning node so that
// existing users of the old tuple output keep seeing a tuple.
//
// NOTE(review): the outputs added to the owning node are not given an
// explicit type here (flattenOutputs calls setType on the values it creates)
// -- confirm the default type is acceptable for non-tensor elements.
static void flattenTupleInBlockReturn(Node* n, size_t index) {
  Value* returned_tuple = n->inputs().at(index);
  Block* enclosing = n->owningBlock();
  Node* owner = enclosing->owningNode();
  TupleTypePtr tt = returned_tuple->type()->cast<TupleType>();
  TORCH_INTERNAL_ASSERT(tt);

  const size_t num_elements = tt->elements().size();
  Node* producer = returned_tuple->node();

  // Step 1: splice the tuple's elements into the block outputs, directly
  // behind the tuple, then remove the tuple itself.
  for (const auto j : c10::irange(num_elements)) {
    enclosing->insertOutput(index + j + 1, producer->inputs().at(j));
  }
  enclosing->eraseOutput(index);

  if (!owner) {
    return;
  }

  // Step 2: do the same for the owning node's outputs. A Loop block carries
  // one extra leading output (the iteration counter), so the node-side index
  // is shifted down by one.
  const size_t out_index = (owner->kind() == prim::Loop) ? index - 1 : index;
  Value* old_output = owner->outputs().at(out_index);

  // For nodes with several blocks the outputs were already expanded while
  // handling an earlier block; nothing left to do then.
  if (!(old_output->type()->cast<TupleType>())) {
    return;
  }

  Node* rebuilt_tuple = enclosing->owningGraph()->create(prim::TupleConstruct);
  rebuilt_tuple->insertAfter(owner);
  for (const auto j : c10::irange(num_elements)) {
    rebuilt_tuple->addInput(owner->insertOutput(out_index + j + 1));
  }

  // Redirect every user of the old tuple output to the rebuilt tuple.
  rebuilt_tuple->output()->setType(old_output->type());
  rebuilt_tuple->copyMetadata(owner);
  old_output->replaceAllUsesWith(rebuilt_tuple->output());
  owner->eraseOutput(out_index);
}
// Redirects the users of a tuple-consuming node (prim::TupleUnpack,
// prim::TupleIndex or prim::TupleSlice) to the values that were packed into
// the tuple, provided the tuple operand is produced directly by a
// prim::TupleConstruct. Nodes of any other kind are ignored.
//
// `n` itself is not destroyed here; only the uses of its outputs are
// rewritten.
//
// If `must_remove_tuples` is true, a tuple operand that does not come from a
// TupleConstruct, or a TupleIndex whose index is not a constant, raises an
// error through TORCH_CHECK. Otherwise such nodes are left unchanged.
void removeTupleNodes(Node* n, bool must_remove_tuples) {
  if (n->kind() != prim::TupleUnpack && n->kind() != prim::TupleIndex &&
      n->kind() != prim::TupleSlice) {
    return;
  }
  // Operand 0 is the tuple for all three kinds (TupleIndex additionally
  // carries the index as operand 1).
  Node* construct_node = n->inputs().at(0)->node();
  if (construct_node->kind() != prim::TupleConstruct) {
    if (must_remove_tuples) {
      TORCH_CHECK(
          false, n->kind().toQualString(), " not matched to tuple construct");
    }
    return;
  }
  if (n->kind() == prim::TupleUnpack) {
    // Output i of the unpack is exactly input i of the construct.
    for (const auto i : c10::irange(n->outputs().size())) {
      n->outputs()[i]->replaceAllUsesWith(construct_node->inputs().at(i));
    }
  } else if (n->kind() == prim::TupleIndex) {
    auto maybe_int = constant_as<int64_t>(n->inputs().at(1));
    if (!maybe_int) {
      if (must_remove_tuples) {
        TORCH_CHECK(
            false, n->sourceRange(), "tuple index with non-constant index");
      }
      return;
    }
    // Keep the whole computation in signed arithmetic: adding a size_t to a
    // negative int64_t would silently go through an unsigned wrap-around.
    const auto len = static_cast<int64_t>(
        construct_node->output()->type()->containedTypes().size());
    int64_t int_idx = *maybe_int;
    if (int_idx < 0) {
      // Negative indices count from the end of the tuple.
      int_idx += len;
    }
    // currently, we allow non-constant tuple index if the tuple is of one
    // type. so we need to check bounds here; an out-of-range index leaves the
    // node untouched.
    if (int_idx >= 0 && int_idx < len) {
      n->output()->replaceAllUsesWith(
          construct_node->inputs().at(static_cast<size_t>(int_idx)));
    }
  } else if (n->kind() == prim::TupleSlice) {
    // Build a smaller tuple from inputs [beg, end) of the construct and let
    // it stand in for the slice.
    const int64_t beg = n->i(attr::beg);
    const int64_t end = n->i(attr::end);
    std::vector<Value*> values;
    for (int64_t i = beg; i < end; ++i) {
      values.push_back(construct_node->inputs().at(static_cast<size_t>(i)));
    }
    auto graph = n->owningGraph();
    auto tuple_out = graph->createTuple(values);
    tuple_out->copyMetadata(n);
    // The replacement tuple is inserted immediately before the slice node.
    WithInsertPoint insert(n);
    graph->insertNode(tuple_out);
    n->output()->replaceAllUsesWith(tuple_out->output());
  }
}
} // anonymous namespace
// Forward declaration: VisitNode recurses into nested blocks through this
// function, which is defined further down in terms of VisitNode.
static void LowerAllTuples(Block* block);
// Replaces a tuple-typed prim::Constant with a TupleConstruct over one
// constant per element. Nodes that are not tuple constants are ignored.
// Element constants that are themselves tuples are expanded recursively.
static void RemoveTupleConstants(Node* n) {
  if (n->kind() != prim::Constant ||
      !n->output()->type()->cast<TupleType>()) {
    return;
  }

  auto* graph = n->owningGraph();
  auto tuple = toIValue(n->output()).value().toTuple();
  const auto& tuple_elements = tuple->elements();

  // Everything created below is placed right before the original constant.
  WithInsertPoint insert(n);
  std::vector<Value*> elements;
  for (const auto& elem : tuple_elements) {
    elements.push_back(insertConstant(*graph, elem));
  }

  // Only forward the tuple type to createTuple when it carries a schema;
  // otherwise pass a null type.
  auto tuple_type = n->output()->type()->expect<TupleType>();
  TupleTypePtr type_for_construct = nullptr;
  if (tuple_type->schema()) {
    type_for_construct = std::move(tuple_type);
  }
  Node* tuple_construct = graph->insertNode(
      graph->createTuple(elements, std::move(type_for_construct)));
  tuple_construct->copyMetadata(n);

  // insert the tuple first before recursing on its elements, so that its
  // elements will have a use
  for (Value* elem : elements) {
    RemoveTupleConstants(elem->node());
  }

  n->replaceAllUsesWith(tuple_construct);
}
// Flattens tuple-typed inputs of `n` in place:
//   op(a, tup, b) --> op(a, t0, t1, b)
// where t0, t1 are the inputs of the TupleConstruct producing `tup`.
// Raises through TORCH_CHECK if a tuple input is not produced by a
// prim::TupleConstruct. For node kinds outside supported_ops a warning is
// emitted and the tuple input is skipped.
// `insert_point` is not used by this function.
static void flattenInputs(Node* n, Node* insert_point) {
  size_t i = 0;
  while (i < n->inputs().size()) {
    Value* input = n->inputs()[i];
    TupleTypePtr tt = input->type()->cast<TupleType>();
    if (!tt) {
      ++i;
      continue;
    }

    TORCH_CHECK(
        (input->node()->kind() == prim::TupleConstruct),
        "tuple use not matched to tuple construct. Instead found: ",
        n->kind().toQualString());

    if (supported_ops.count(n->kind()) == 0) {
      TORCH_WARN(
          "tuple appears in op inputs, but this op does not forward tuples, ",
          "unsupported kind: ",
          n->kind().toQualString());
      ++i;
      continue;
    }

    if (n->kind() == prim::Loop) {
      // Loop inputs are paired with block parameters, which have to be
      // expanded together with the node inputs.
      flattenTupleInLoopParams(n, i);
    } else if (n->kind() == prim::Return) {
      // Block outputs are paired with the outputs of the owning node.
      flattenTupleInBlockReturn(n, i);
    } else {
      Node* producer = input->node();
      const size_t num_elements = tt->elements().size();
      for (const auto j : c10::irange(num_elements)) {
        n->insertInput(i + 1 + j, producer->inputs().at(j));
      }
      n->removeInput(i);
    }
    // note: no update to i
    // since tuples might be nested we need to recursively scan
    // the new flattened inputs
  }
}
// Flattens tuple-typed outputs of `n` that have at least one use:
//   (a, b, tup, c) -> (a, b, t0, t1, c)
// and a node
//   tup = (t0, t1)
// is inserted before `insert_point` to serve the existing users of `tup`.
// Each tuple created this way becomes the insertion point for the next one.
// For node kinds outside supported_ops a warning is emitted and the output
// is skipped.
static void flattenOutputs(Node* n, Node* insert_point) {
  auto& graph = *n->owningGraph();
  size_t i = 0;
  while (i < n->outputs().size()) {
    Value* output = n->outputs()[i];
    // Unused outputs are left as they are.
    if (!output->hasUses()) {
      ++i;
      continue;
    }

    TupleTypePtr tt = output->type()->cast<TupleType>();
    if (!tt) {
      ++i;
      continue;
    }

    if (supported_ops.count(n->kind()) == 0) {
      TORCH_WARN(
          "tuple appears in the op outputs, but this op does not forward tuples, ",
          "unsupported kind: ",
          n->kind().toQualString());
      ++i;
      continue;
    }

    // Add one typed output per element directly behind the tuple output.
    const size_t num_elements = tt->elements().size();
    for (size_t j = 0; j < num_elements; ++j) {
      n->insertOutput(i + 1 + j)->setType(tt->elements()[j]);
    }

    // Rebuild the tuple from the new outputs for the existing users.
    Node* new_tup = graph.createTuple(n->outputs().slice(i + 1, num_elements));
    new_tup->copyMetadata(n);
    new_tup->insertBefore(insert_point);
    insert_point = new_tup;
    output->replaceAllUsesWith(new_tup->output());
    n->eraseOutput(i);
    // note: no update to i to handle nested tuples
  }
}
// Lowers the tuples touching a single node: tuple-consuming nodes are
// rewired onto the elements of their TupleConstruct, every other node has its
// inputs flattened, its nested blocks lowered and then its outputs flattened.
// `insert_point` is forwarded to flattenInputs and flattenOutputs.
static void VisitNode(Node* n, Node* insert_point) {
  const auto kind = n->kind();

  // tuple construction operators will become dead when the unpacks are
  // replaced
  if (kind == prim::TupleConstruct) {
    return;
  }

  const bool consumes_tuple = kind == prim::TupleUnpack ||
      kind == prim::TupleIndex || kind == prim::TupleSlice;
  if (consumes_tuple) {
    // note: changing the second argument to false changes this pass from a
    // complete lowering pass to one that removes tuples when possible. When
    // tuples are first-class in the interpreter, we should still run this
    // pass to remove extraneous uses
    removeTupleNodes(n, /*must_remove_tuples=*/true);
    return;
  }

  flattenInputs(n, insert_point);
  for (Block* nested : n->blocks()) {
    LowerAllTuples(nested);
  }
  flattenOutputs(n, insert_point);
}
// Lowers all tuples inside `block`: first its parameters, then every node in
// order (expanding tuple constants beforehand), and finally its return node.
static void LowerAllTuples(Block* block) {
  // tuples in parameter lists of a block behave exactly the same as
  // _outputs_ of normal instructions, since the param_node represents the
  // parameters as outputs, we can handle it by simply visiting the node
  VisitNode(block->param_node(), *block->nodes().begin());

  auto cursor = block->nodes().begin();
  const auto stop = block->nodes().end();
  while (cursor != stop) {
    Node* current = *cursor;
    // Advance before visiting: the node that follows `current` is handed to
    // VisitNode as the insertion point.
    ++cursor;
    RemoveTupleConstants(current);
    VisitNode(current, *cursor);
  }

  // tuples in return lists of blocks behave exactly the same as
  // _inputs_ of normal instructions, so we can use VisitNode here as well
  // insert_point is null because it will never be used since return nodes
  // have no outputs
  VisitNode(block->return_node(), nullptr);
}
// Raises through TORCH_CHECK if any of `values` still has a tuple type.
static void EnsureNoTuples(ArrayRef<Value*> values) {
  for (size_t k = 0; k < values.size(); ++k) {
    Value* v = values[k];
    TORCH_CHECK(
        v->type()->kind() != TypeKind::TupleType, "Couldn't lower all tuples.");
  }
}
// Verifies that no node output in `block` has a tuple type. For each node
// its nested blocks are checked first, then the node's own outputs.
static void EnsureNoTuples(Block* block) {
  for (Node* node : block->nodes()) {
    for (Block* nested : node->blocks()) {
      EnsureNoTuples(nested);
    }
    EnsureNoTuples(node->outputs());
  }
}
void LowerAllTuples(const std::shared_ptr<Graph>& graph) {
LowerAllTuples(graph->block());
GRAPH_DUMP("After LowerAllTuples: ", graph);
EliminateDeadCode(graph->block());
EnsureNoTuples(graph->block());
}
// Best-effort variant of the lowering: for every node in `block`, and
// recursively in its nested blocks, tuple-consuming nodes fed directly by a
// TupleConstruct are rewired; anything that cannot be matched is left alone.
void LowerSimpleTuples(Block* block) {
  for (Node* node : block->nodes()) {
    removeTupleNodes(node, /*must_remove_tuples=*/false);
    for (Block* nested : node->blocks()) {
      LowerSimpleTuples(nested);
    }
  }
}
void LowerSimpleTuples(const std::shared_ptr<Graph>& graph) {
LowerSimpleTuples(graph->block());
GRAPH_DUMP("After LowerSimpleTuples: ", graph);
EliminateDeadCode(graph);
}
} // namespace torch::jit