[SR] Add BlockRunner and handle sub-blocks (#69834)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/69834

* Modify the `StaticModule` constructor to handle index initialization for sub-blocks.
* Add a new class `BlockRunner`. This class is almost exactly like what we've been calling `StaticRuntime` up to this point, except that it does not own a `values_` array. Each `BlockRunner` holds an unowned reference to a `values_` array owned by `StaticRuntime`. This is a useful abstraction for implementing control flow: it gives sub-blocks a way to look up values from surrounding scopes (see the sketch below).
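
To make the ownership and indexing story concrete, here is a minimal, self-contained sketch of the pattern the two bullets describe. It is illustrative only, not the actual Static Runtime implementation; the `Toy*` names, `assignIndices`, and the member layout are all made up. It shows one runtime-owned values array, a recursive depth-first index assignment for blocks and sub-blocks, and block runners that hold an unowned reference into the shared array so a sub-block can read a value written by an enclosing scope.

```cpp
// Illustrative sketch only -- not PyTorch code. All names are hypothetical.
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

using ToyValue = int64_t; // stand-in for c10::IValue

// A block defines some values and may contain nested sub-blocks.
struct ToyBlock {
  size_t num_values;
  std::vector<ToyBlock> sub_blocks;
};

// Recursively assign each block a contiguous index range in one shared array
// (depth-first), recording each block's starting offset. Returns the number
// of values covered by `block` including its sub-blocks.
size_t assignIndices(
    const ToyBlock& block,
    size_t start,
    std::vector<size_t>& block_offsets) {
  block_offsets.push_back(start);
  size_t cur = start + block.num_values;
  for (const auto& sub : block.sub_blocks) {
    cur += assignIndices(sub, cur, block_offsets);
  }
  return cur - start;
}

// Analogue of a block runner: holds an *unowned* reference to the array.
class ToyBlockRunner {
 public:
  ToyBlockRunner(std::vector<ToyValue>& values, size_t offset)
      : values_(values), offset_(offset) {}

  // Any index is visible, so a sub-block can read values written by an
  // enclosing scope (they live in the same array at smaller offsets).
  ToyValue read(size_t global_idx) const { return values_[global_idx]; }
  void write(size_t local_idx, ToyValue v) { values_[offset_ + local_idx] = v; }

 private:
  std::vector<ToyValue>& values_; // owned by ToyRuntime, not by the runner
  size_t offset_;
};

// Analogue of the runtime: the sole owner of the values array.
class ToyRuntime {
 public:
  explicit ToyRuntime(const ToyBlock& root) {
    std::vector<size_t> offsets;
    values_.resize(assignIndices(root, /*start=*/0, offsets));
    for (size_t off : offsets) {
      runners_.emplace_back(values_, off);
    }
  }
  ToyBlockRunner& runner(size_t i) { return runners_[i]; }

 private:
  std::vector<ToyValue> values_;
  std::vector<ToyBlockRunner> runners_;
};

int main() {
  // Root block defines 2 values; it has one sub-block defining 1 value.
  ToyBlock root{2, {ToyBlock{1, {}}}};
  ToyRuntime rt(root);

  rt.runner(0).write(0, 42);          // outer scope writes values[0]
  assert(rt.runner(1).read(0) == 42); // sub-block sees the outer value
  return 0;
}
```

In the same spirit, the diff below has the block runner for a `prim::If`/`prim::Loop` sub-block share the runtime's `values_` array instead of copying values into its own storage.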
ghstack-source-id: 148086245

Test Plan: `buck test caffe2/benchmarks/static_runtime/...`

Reviewed By: d1jang

Differential Revision: D33028039

fbshipit-source-id: 4f01417bad51a0cf09b1680a518308da647be1f6
Mike Iovine 2022-02-01 09:08:41 -08:00 committed by Facebook GitHub Bot
parent bdcc5f5f47
commit 3a9feffd92
5 changed files with 680 additions and 356 deletions


@@ -106,7 +106,8 @@ TEST(StaticModule, ValueGroup) {
   torch::jit::StaticModule sm(input_graph);
   const Graph& graph = sm.graph();
   std::vector<const Node*> nodes(graph.nodes().begin(), graph.nodes().end());
-  const auto& value_group = sm.value_group();
+  auto* root_block = sm.root_block();
+  const auto& value_group = sm.block_info(root_block).value_group();
 
   std::vector<const Value*> expected_input_aliases{
       graph.inputs()[0], graph.inputs()[1], nodes[0]->output()};
@@ -138,9 +139,11 @@ TEST(StaticModule, IsOptimizableContainerType_NonOptimizableInputs) {
   auto sm = makeStaticModuleFromScript(src);
   const auto& graph = sm.graph();
+  auto* root_block = sm.root_block();
+  const auto& block_info = sm.block_info(root_block);
 
   for (const Node* n : graph.nodes()) {
-    EXPECT_FALSE(sm.is_optimizable_container_type(n));
+    EXPECT_FALSE(block_info.node_is_optimizable_container_type(n));
   }
 }
@@ -158,9 +161,11 @@ TEST(StaticModule, IsOptimizableContainerType_WrongType) {
   auto sm = makeStaticModuleFromScript(src);
   const auto& graph = sm.graph();
+  auto* root_block = sm.root_block();
+  const auto& block_info = sm.block_info(root_block);
 
   for (const Node* n : graph.nodes()) {
-    EXPECT_FALSE(sm.is_optimizable_container_type(n));
+    EXPECT_FALSE(block_info.node_is_optimizable_container_type(n));
   }
 }
@@ -175,12 +180,14 @@ TEST(StaticModule, IsOptimizableContainerType_CanUseOutVariant) {
   )JIT";
   auto sm = makeStaticModuleFromScript(src);
   const auto& graph = sm.graph();
+  auto* root_block = sm.root_block();
+  const auto& block_info = sm.block_info(root_block);
 
   for (const Node* n : graph.nodes()) {
     if (n->kind() == c10::prim::ListConstruct) {
-      EXPECT_TRUE(sm.is_optimizable_container_type(n));
+      EXPECT_TRUE(block_info.node_is_optimizable_container_type(n));
     } else {
-      EXPECT_FALSE(sm.is_optimizable_container_type(n));
+      EXPECT_FALSE(block_info.node_is_optimizable_container_type(n));
     }
   }
 }
@@ -1050,7 +1057,8 @@ TEST(ManagedTensorRanges, NoAliases) {
   auto* z = vmap["z"];
 
   FastSet<const Value*> managed_tensors = {y, z};
-  ManagedTensorRanges ranges(graph, managed_tensors);
+  AliasDb alias_db(graph);
+  auto ranges = ManagedTensorRanges(*graph->block(), alias_db, managed_tensors);
 
   std::vector<Node*> nodes(
       graph->block()->nodes().begin(), graph->block()->nodes().end());
@@ -1089,7 +1097,8 @@ TEST(ManagedTensorRanges, AliasExtendingLifetimes) {
   auto* z2 = vmap["z2"];
 
   FastSet<const Value*> managed_tensors = {y, z1, z2};
-  ManagedTensorRanges ranges(graph, managed_tensors);
+  AliasDb alias_db(graph);
+  auto ranges = ManagedTensorRanges(*graph->block(), alias_db, managed_tensors);
 
   std::vector<Node*> nodes(
       graph->block()->nodes().begin(), graph->block()->nodes().end());
@@ -1135,7 +1144,8 @@ TEST(ManagedTensorRanges, LifetimeOverlap) {
   auto* d = vmap["d"];
   auto* e = vmap["e"];
-  ManagedTensorRanges ranges(graph, {b, c, d, e});
+  AliasDb alias_db(graph);
+  auto ranges = ManagedTensorRanges(*graph->block(), alias_db, {b, c, d, e});
 
   const std::vector<std::pair<Value*, Value*>> overlapping_values{
       {b, c}, {c, d}, {c, e}};
@@ -1169,7 +1179,8 @@ TEST(ManagedTensorRanges, OverlappingLifetimesContainers) {
   auto* c = vmap["c"];
   auto* d = vmap["d"];
-  ManagedTensorRanges ranges(graph, {b, c, d});
+  AliasDb alias_db(graph);
+  auto ranges = ManagedTensorRanges(*graph->block(), alias_db, {b, c, d});
 
   EXPECT_TRUE(ranges.lifetimesOverlap(b, c));
   EXPECT_TRUE(ranges.lifetimesOverlap(b, d));
@@ -1189,7 +1200,8 @@ TEST(ManagedTensorRanges, OverlappingLifetimesOutputs) {
   auto* b = vmap["b"];
   auto* output = vmap["output"];
-  ManagedTensorRanges ranges(graph, {b, output});
+  AliasDb alias_db(graph);
+  auto ranges = ManagedTensorRanges(*graph->block(), alias_db, {b, output});
 
   EXPECT_TRUE(ranges.lifetimesOverlap(b, output));
 }
@@ -1275,7 +1287,9 @@ void testAssignStorageToManagedTensors(
   }
   ASSERT_EQ(managed_tensor_values.size(), tensor_value_to_tensor.size());
 
-  auto ranges = ManagedTensorRanges(graph, managed_tensor_values);
+  AliasDb alias_db(graph);
+  auto ranges =
+      ManagedTensorRanges(*graph->block(), alias_db, managed_tensor_values);
   auto groups = assignStorageToManagedTensors(
       graph->block()->nodes(), ranges, tensor_value_to_tensor);


@@ -19,6 +19,7 @@
 #include <torch/csrc/jit/passes/remove_mutation.h>
 #include <torch/csrc/jit/passes/subgraph_rewrite.h>
 #include <torch/csrc/jit/passes/variadic_ops.h>
+#include <torch/csrc/jit/runtime/graph_iterator.h>
 #include <torch/csrc/jit/runtime/static/fusion.h>
 #include <torch/csrc/jit/runtime/static/memory_planner.h>
 #include <torch/csrc/jit/runtime/static/ops.h>
@@ -32,6 +33,7 @@
 #endif
 
 #include <iterator>
+#include <limits>
 #include <sstream>
 #include <stdexcept>
@@ -58,8 +60,8 @@ bool isUnsupportedOp(const NodeKind& kind) {
   return kind == aten::__is__ || kind == aten::__isnot__;
 }
 
-// graph must be frozen or canEnableStaticRuntime would return false if there's
-// any prim::CallMethod op left in the graph
+// graph must be frozen or canEnableStaticRuntime would return false
+// if there's any prim::CallMethod op left in the graph
 bool canEnableStaticRuntime(const std::shared_ptr<torch::jit::Graph>& graph) {
   // check for sub-blocks
   bool can_support = true;
@@ -181,26 +183,20 @@ std::vector<Value*> valueVecFromFastSet(const FastSet<const Value*>& s) {
   return result;
 }
 
-bool mayContainAlias(AliasDb& db, const Value* a, const Value* b) {
+bool mayContainAlias(const AliasDb& db, const Value* v1, const Value* v2) {
+  // AliasDb is not const-correct here, so we have to const_cast
   // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
-  return db.mayContainAlias(const_cast<Value*>(a), const_cast<Value*>(b));
+  return db.mayContainAlias(const_cast<Value*>(v1), const_cast<Value*>(v2));
 }
 
 bool mayContainAlias(
-    AliasDb& db,
+    const AliasDb& db,
     const Value* a,
     const FastSet<const Value*>& b) {
   // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
   return db.mayContainAlias(const_cast<Value*>(a), valueVecFromFastSet(b));
 }
 
-bool mayContainAlias(
-    AliasDb& db,
-    const FastSet<const Value*>& a,
-    const FastSet<const Value*>& b) {
-  return db.mayContainAlias(valueVecFromFastSet(a), valueVecFromFastSet(b));
-}
-
 void PrepareGraphForStaticModule(
     std::shared_ptr<torch::jit::Graph> graph,
     const StaticModuleOptions& opts,
@@ -248,23 +244,21 @@ std::pair<std::shared_ptr<Graph>, c10::optional<Module>> PrepareForStaticModule(
 } // namespace
 
-void ValueGroup::init(
-    const std::shared_ptr<torch::jit::Graph>& graph,
-    AliasDb& db) {
+void ValueGroup::init(const Block& block, const AliasDb& db) {
   external_aliases_.clear();
   output_aliases_.clear();
   // Build `external_aliases` as we look through nodes forwardly from
   // the graph's inputs and add aliases of the inputs being created by the
   // nodes.
-  external_aliases_.insert(graph->inputs().begin(), graph->inputs().end());
-  for (const auto* node : graph->nodes()) {
+  external_aliases_.insert(block.inputs().begin(), block.inputs().end());
+  for (const auto* node : block.nodes()) {
     if (node->kind() == prim::Constant) {
       for (const auto* output : node->outputs()) {
         external_aliases_.insert(output);
       }
     }
   }
-  for (const auto* node : graph->nodes()) {
+  for (const auto* node : block.nodes()) {
     if (node->kind() == prim::Constant) {
       // Constants are already in `external_aliases`.
       continue;
@@ -278,8 +272,8 @@ void ValueGroup::init(
 
   // Build `output_aliases` as we look through nodes reversely so that we can
   // start from the output values, and follow the flows backwardly from there.
-  output_aliases_.insert(graph->outputs().begin(), graph->outputs().end());
-  for (const auto* node : graph->nodes().reverse()) {
+  output_aliases_.insert(block.outputs().begin(), block.outputs().end());
+  for (const auto* node : block.nodes().reverse()) {
     if (node->kind() == prim::Constant) {
       // Constants cannot create any aliases.
       continue;
@@ -317,12 +311,6 @@ bool containTensorsOnly(at::ArrayRef<Value*> values) {
   });
 }
 
-bool mayContainAlias(const Value* v1, const Value* v2, const AliasDb& db) {
-  // AliasDb is not const-correct here, so we have to const_cast
-  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
-  return db.mayContainAlias(const_cast<Value*>(v1), const_cast<Value*>(v2));
-}
-
 bool isPureFunction(const Node* node) {
   auto* schema = node->maybeSchema();
   return schema &&
@@ -332,12 +320,12 @@ bool isPureFunction(const Node* node) {
 } // namespace
 
 ManagedTensorRanges::ManagedTensorRanges(
-    const std::shared_ptr<Graph>& graph,
+    Block& block,
+    const AliasDb& alias_db,
     const FastSet<const Value*>& managed_tensor_values) {
-  AliasDb alias_db(graph);
-  const std::vector<Node*> nodes(graph->nodes().begin(), graph->nodes().end());
+  const std::vector<Node*> nodes(block.nodes().begin(), block.nodes().end());
   const FastSet<const Value*> graph_inputs(
-      graph->inputs().begin(), graph->inputs().end());
+      block.inputs().begin(), block.inputs().end());
 
   auto isUntrackedValue = [&alias_db, &graph_inputs](const Value* value) {
     return !alias_db.isMutableType(value) ||
@@ -363,7 +351,7 @@ ManagedTensorRanges::ManagedTensorRanges(
       value_lifetimes_.emplace(output, Lifetime(i, i));
     }
   }
-  for (auto* graph_output : graph->outputs()) {
+  for (auto* graph_output : block.outputs()) {
     auto* lifetime = getLifetime(graph_output);
     if (!lifetime) {
       DCHECK(isUntrackedValue(graph_output));
@@ -376,7 +364,7 @@ ManagedTensorRanges::ManagedTensorRanges(
   // has an input and output that may alias each other, set the input's
   // lifetime end to max(input.lifetime_end, output.lifetime_end). Iterate
   // backwards to handle chains of aliases.
-  for (const auto* node : graph->nodes().reverse()) {
+  for (const auto* node : block.nodes().reverse()) {
     if (isPureFunction(node)) {
       // If the node is a pure function, it doesn't create any aliases,
       // so we can safely skip it.
@@ -389,7 +377,7 @@ ManagedTensorRanges::ManagedTensorRanges(
       auto* input_lifetime = getLifetime(input);
       DCHECK(input_lifetime != nullptr);
       for (auto* output : outputs) {
-        if (mayContainAlias(input, output, alias_db)) {
+        if (mayContainAlias(alias_db, input, output)) {
           auto* output_lifetime = getLifetime(output);
           DCHECK(output_lifetime != nullptr);
           input_lifetime->end =
@@ -404,7 +392,7 @@ ManagedTensorRanges::ManagedTensorRanges(
     // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
     Node* freeing_node;
     if (lifetime->end == num_nodes) {
-      freeing_node = graph->return_node();
+      freeing_node = block.return_node();
     } else {
       freeing_node = nodes[lifetime->end];
     }
@@ -519,15 +507,6 @@ StaticModule::StaticModule(
     }
   }
 
-  // map Value* to its SSA definition IR
-  FastMap<Value*, DefInfo> value_to_ssa_def;
-
-  // N inputs map to the first N entries in storage
-  for (const auto i : c10::irange(graph_->inputs().size())) {
-    Value* input = graph_->inputs()[i];
-    value_to_ssa_def[input] = std::make_pair(INPUT_VALUE, i);
-  }
-
   {
     size_t nodes_size = 0, constants_size = 0;
     for (Node* node : graph_->nodes()) {
@@ -536,7 +515,6 @@ StaticModule::StaticModule(
 
     constants_.reserve(constants_size);
     functions_.reserve(nodes_size);
-    nodes_.reserve(nodes_size);
   }
 
   // Create ProcessedFunction instances first to freeze their addresses to pass
@@ -544,13 +522,89 @@ StaticModule::StaticModule(
   AliasDb alias_db(graph_, /*isFrozen=*/false);
   GRAPH_DEBUG("AliasDb: ", alias_db.toString());
 
-  // Construct constant and function nodes
-  for (Node* node : graph_->nodes()) {
+  // Maps each Value* in the graph to its index in the values_ array that will
+  // eventually be created by StaticRuntime.
+  FastMap<const Value*, uint32_t> value_to_index;
+  prepareFunctionsAndConstants(graph_->block(), alias_db, value_to_index);
+
+  const auto constants_index_offset = 0;
+  const auto values_index_offset = constants_index_offset + constants().size();
+  value_buffer_size_ = values_index_offset;
+  value_buffer_size_ +=
+      prepareBlockInfo(graph_->block(), values_index_offset, value_to_index);
+
+  prepareProcessedNodes(graph_->block(), value_to_index, alias_db);
+
+  for (auto& block_and_info : block_infos_) {
+    auto& block_info = block_and_info.second;
+    block_info.prepare_for_memory_planner(alias_db, opts);
+  }
+}
+
+size_t StaticModule::prepareBlockInfo(
+    Block* block,
+    const size_t start_idx,
+    FastMap<const Value*, uint32_t>& value_to_index) {
+  block_infos_.emplace(block, BlockInfo(start_idx, *block));
+
+  const auto num_inputs = block->inputs().size();
+  for (const auto i : c10::irange(num_inputs)) {
+    value_to_index.emplace(block->inputs()[i], start_idx + i);
+  }
+  auto cur_idx = start_idx + num_inputs;
+
+  for (auto* node : block->nodes()) {
+    for (auto* sub_block : node->blocks()) {
+      cur_idx += prepareBlockInfo(sub_block, cur_idx, value_to_index);
+    }
+
+    if (node->kind() == prim::Constant) {
+      continue;
+    }
+
+    TORCH_CHECK(
+        cur_idx < (1 << 16),
+        "outputs offset in values table",
+        cur_idx,
+        " would overflow 2-byte index storage");
+
+    const auto num_outputs = node->outputs().size();
+    for (const auto i : c10::irange(num_outputs)) {
+      value_to_index.emplace(node->outputs()[i], cur_idx + i);
+    }
+    cur_idx += num_outputs;
+  }
+
+  std::vector<uint16_t> output_indices;
+  output_indices.reserve(block->outputs().size());
+  for (auto* output : block->outputs()) {
+    const auto output_idx = value_to_index.at(output);
+    TORCH_CHECK(
+        output_idx < (1 << 16),
+        "outputs offset in values table",
+        output_idx,
+        " would overflow 2-byte index storage");
+    output_indices.push_back(output_idx);
+  }
+
+  block_infos_.at(block).set_output_indices(std::move(output_indices));
+
+  return cur_idx - start_idx;
+}
+
+void StaticModule::prepareFunctionsAndConstants(
+    Block* block,
+    const AliasDb& alias_db,
+    FastMap<const Value*, uint32_t>& value_to_index) {
+  for (auto* node : block->nodes()) {
+    for (auto* sub_block : node->blocks()) {
+      prepareFunctionsAndConstants(sub_block, alias_db, value_to_index);
+    }
+
     if (node->kind() == prim::Constant) {
       auto* v = node->output();
       TORCH_CHECK(v->type()->kind() != FunctionType::Kind);
-      // construct SSA definition for constant nodes
-      value_to_ssa_def[v] = std::make_pair(CONSTANT_VALUE, constants_.size());
+      value_to_index.emplace(v, constants_.size());
       constants_.emplace_back(toIValue(v).value());
       continue;
     }
@@ -561,66 +615,34 @@ StaticModule::StaticModule(
         containTensorsOnly(node->outputs());
     // new ProcessedFunction
     functions_.emplace_back(
-        node, opts.enable_out_variant, check_outputs_for_overlap);
+        node, opts_.enable_out_variant, check_outputs_for_overlap);
   }
+}
 
-  // construct SSA definition for non-constant nodes
-  int node_idx = 0;
+size_t StaticModule::prepareProcessedNodes(
+    Block* block,
+    const FastMap<const Value*, uint32_t>& value_to_index,
+    const AliasDb& alias_db,
+    size_t node_idx) {
+  const auto node_start = node_idx;
+  auto& block_info = block_infos_.at(block);
+  std::vector<ProcessedNode> nodes;
   FastMap<Node*, bool> node_has_out_variant;
-  const auto inputs_index_offset = inputs_offset();
-  const auto constants_index_offset = constants_offset();
-  const auto values_index_offset = intermediate_values_offset();
-
-  // Map node_idx to index offset in values_. Can't reserve space
-  // because we don't know how many non-constant nodes there are yet.
-  std::vector<uint32_t> node_output_idx_map;
-  uint32_t node_outputs_seen_so_far = 0;
-  for (Node* node : graph_->nodes()) {
+
+  for (auto* node : block->nodes()) {
     if (node->kind() == prim::Constant) {
       continue;
     }
-    // Assign memory for the outputs
-    const auto outputs_offset_for_node =
-        node_outputs_seen_so_far + values_index_offset;
-    TORCH_CHECK(
-        outputs_offset_for_node < (1 << 16),
-        "outputs offset in values table",
-        outputs_offset_for_node,
-        " would overflow 2-byte index storage");
-    node_output_idx_map.push_back(outputs_offset_for_node);
-    node_outputs_seen_so_far += node->outputs().size();
-  }
-
-  for (Node* node : graph_->nodes()) {
-    if (node->kind() == prim::Constant) {
-      continue;
+
+    for (auto* sub_block : node->blocks()) {
+      node_idx +=
+          prepareProcessedNodes(sub_block, value_to_index, alias_db, node_idx);
     }
+
     ProcessedNodeInputs input_indices(node->inputs().size());
-    std::vector<DefInfo> input_ssa_defs;
     for (const auto input_idx : c10::irange(node->inputs().size())) {
-      Value* const input = node->inputs()[input_idx];
-      int inner_node_idx = 0;
-      int out_idx = 0;
-      std::tie(inner_node_idx, out_idx) = value_to_ssa_def.at(input);
-      unsigned int input_ivalue_idx = 0;
-      if (inner_node_idx == StaticModule::INPUT_VALUE) {
-        input_ivalue_idx = out_idx + inputs_index_offset;
-      } else if (inner_node_idx == StaticModule::CONSTANT_VALUE) {
-        input_ivalue_idx = out_idx + constants_index_offset;
-      } else {
-        DCHECK_GE(inner_node_idx, 0);
-        const auto global_value_idx =
-            node_output_idx_map[inner_node_idx] + out_idx;
-        if (inner_node_idx < node_output_idx_map.size() - 1) {
-          DCHECK_LT(global_value_idx, node_output_idx_map[inner_node_idx + 1]);
-        } else {
-          DCHECK_LT(
-              global_value_idx,
-              constants_index_offset + node_outputs_seen_so_far);
-        }
-        input_ivalue_idx = global_value_idx;
-      }
+      auto* input = node->inputs()[input_idx];
+      auto input_ivalue_idx = value_to_index.at(input);
       TORCH_CHECK(
           input_ivalue_idx < (1 << 16),
           "input index in values table ",
@@ -630,72 +652,48 @@ StaticModule::StaticModule(
     }
 
     ProcessedFunction* fn = &functions_[node_idx];
     // create a new ProcessedNode
-    // see [Check and correct bad schema alias info at runtime]
-    bool check_outputs_for_overlap =
-        !alias_db.mayContainAlias(node->inputs(), node->outputs()) &&
-        containTensorsOnly(node->outputs());
-    nodes_.emplace_back(
-        node, fn, std::move(input_indices), node_output_idx_map[node_idx]);
-    node_has_out_variant.emplace(node, nodes_.back().has_out_variant());
-    for (const auto i : c10::irange(node->outputs().size())) {
-      value_to_ssa_def[node->outputs()[i]] = std::make_pair(node_idx, i);
-    }
-    node_idx++;
+    const auto node_output_idx = node->outputs().empty()
+        // The index is unused if there are no outputs, so just create a
+        // placeholder value.
+        ? std::numeric_limits<uint16_t>::max()
+        : value_to_index.at(node->output(0));
+    nodes.emplace_back(node, fn, std::move(input_indices), node_output_idx);
+
+    node_has_out_variant.emplace(node, nodes.back().has_out_variant());
+    ++node_idx;
   }
 
-  num_intermediate_values_ = std::accumulate(
-      nodes_.begin(),
-      nodes_.end(),
-      0,
-      [](uint32_t sum, const ProcessedNode& pnode) {
-        return sum + pnode.num_outputs();
-      });
-
-  for (auto& pnode : nodes_) {
-    if (pnode.num_outputs() == 1 &&
-        isOptimizableContainerType(pnode.node(), node_has_out_variant)) {
-      node_is_optimizable_container_type_.emplace(pnode.node());
-    }
-  }
-
-  output_indices_.reserve(graph_->outputs().size());
-  for (auto output : graph_->outputs()) {
-    int node_idx = 0;
-    int out_idx = 0;
-    std::tie(node_idx, out_idx) = value_to_ssa_def[output];
-    uint32_t output_index = 0;
-    if (node_idx == StaticModule::INPUT_VALUE) {
-      output_index = out_idx + inputs_index_offset;
-    } else if (node_idx == StaticModule::CONSTANT_VALUE) {
-      output_index = constants_index_offset + out_idx;
-    } else {
-      output_index = nodes_[node_idx].output_ivalue_index(out_idx);
-    }
-    TORCH_CHECK(
-        output_index < (1 << 16),
-        "output index ",
-        output_index,
-        " would overflow 2-byte index storage");
-    output_indices_.emplace_back(output_index);
-  }
-
-  // Prepare for memory planning
-  value_group_.init(graph_, alias_db);
-  GRAPH_DEBUG(value_group_.toString());
+  block_info.set_nodes(std::move(nodes), node_has_out_variant);
+  block_info.init_value_group(alias_db);
 
-  prepareForMemoryPlanner();
+  return node_idx - node_start;
 }
 
-void StaticModule::prepareForMemoryPlanner() {
-  if (!opts_.enable_out_variant) {
+void BlockInfo::set_nodes(
+    std::vector<ProcessedNode> nodes,
+    const FastMap<Node*, bool>& node_has_out_variant) {
+  nodes_ = std::move(nodes);
+
+  for (auto& node : nodes_) {
+    if (node.num_outputs() == 1 &&
+        isOptimizableContainerType(node.node(), node_has_out_variant)) {
+      node_is_optimizable_container_type_.emplace(node.node());
+    }
+  }
+}
+
+void BlockInfo::prepare_for_memory_planner(
+    const AliasDb& alias_db,
+    const StaticModuleOptions& opts) {
+  if (!opts.enable_out_variant) {
     return;
   }
 
   // Never manage graph outputs so that we can do std::move(output_ivalue).
   // This does not affect performance if the graph returns a collection object.
   FastSet<const Value*> graph_output_values(
-      graph_->outputs().begin(), graph_->outputs().end());
+      block_.outputs().begin(), block_.outputs().end());
 
   // collect register indices of outputs of ops with out variant
   for (ProcessedNode& pnode : nodes_) {
@@ -707,7 +705,7 @@ void StaticModule::prepareForMemoryPlanner() {
       const Value* out_v = outputs[i];
       // Types are stored in the underlying TorchScript IR
       bool is_tensor_type = out_v->type()->castRaw<TensorType>();
-      if (opts_.manage_output_tensors && is_tensor_type &&
+      if (opts.manage_output_tensors && is_tensor_type &&
           graph_output_values.find(out_v) == graph_output_values.end() &&
           value_group_.isOutputAlias(out_v)) {
         managed_output_tensor_values_.insert(out_v);
@@ -718,7 +716,7 @@
       }
       if (is_tensor_type) {
         managed_tensor_values_.insert(out_v);
-      } else if (is_optimizable_container_type(pnode.node())) {
+      } else if (node_is_optimizable_container_type(pnode.node())) {
         // We "leak" certain container types because their allocations
         // take a long time
         leaked_values_.insert(out_v);
@@ -726,7 +724,7 @@
     }
   }
 
-  for (const Value* output : graph_->outputs()) {
+  for (const Value* output : block_.outputs()) {
     managed_tensor_values_.erase(output);
   }
   GRAPH_DEBUG("managed_tensor_values: ", dumpValueSet(managed_tensor_values_));
@@ -734,7 +732,8 @@
       "managed_output_tensor_values_: ",
       dumpValueSet(managed_output_tensor_values_));
 
-  managed_tensor_ranges_ = ManagedTensorRanges(graph_, managed_tensor_values_);
+  managed_tensor_ranges_ =
+      ManagedTensorRanges(block_, alias_db, managed_tensor_values_);
 }
 
 const StaticModuleOptions& StaticModule::opts() const {
@@ -757,9 +756,12 @@ StaticRuntime& StaticModule::runtime() {
 }
 
 Node* StaticModule::findNodeWithKindForTesting(const std::string& kind) const {
-  for (auto& pnode : nodes()) {
-    if (pnode.node()->kind().toQualString() == kind) {
-      return pnode.node();
+  for (auto& block_and_info : block_infos_) {
+    auto& block_info = block_and_info.second;
+    for (auto& pnode : block_info.nodes()) {
+      if (pnode.node()->kind().toQualString() == kind) {
+        return pnode.node();
+      }
     }
   }
   return nullptr;
@@ -777,41 +779,64 @@ c10::IValue StaticModule::operator()(
   return runtime()(std::move(args), kwargs);
 }
 
-StaticRuntime::StaticRuntime(const StaticModule& sm)
+BlockRunner::BlockRunner(
+    const StaticModule& sm,
+    std::vector<IValue>& values,
+    Block* block,
+    bool is_root_block)
     : static_module_(sm),
-      first_input_is_self_(static_module_.first_input_is_self()),
-      manage_output_tensors_enabled_(sm.opts().manage_output_tensors),
-      nodes_(sm.nodes()) {
-  values_.resize(sm.total_num_values());
-  const auto constants_index_offset = sm.constants_offset();
-  const auto constants_begin_it = values_.begin() + constants_index_offset;
-  const auto constants_end_it = constants_begin_it + sm.constants().size();
-  std::copy(sm.constants().begin(), sm.constants().end(), constants_begin_it);
-
-  for (const auto idx : c10::irange(sm.nodes().size())) {
-    auto& n = nodes_[idx];
+      block_info_(static_module_.block_info(block)),
+      is_root_block_(is_root_block),
+      first_input_is_self_(
+          is_root_block_ && static_module_.first_input_is_self()),
+      inputs_begin_(block_info_.block_inputs_idx()),
+      // TODO(T108633124): Turn on manage output tensors for sub-blocks.
+      manage_output_tensors_enabled_(
+          is_root_block_ && sm.opts().manage_output_tensors),
+      values_(values),
+      nodes_(block_info_.nodes()) {
+  for (auto& n : nodes_) {
     n.set_values(values_.data());
   }
 
-  // TODO: can we convert outputs_ to store indices?
-  for (auto index : sm.output_indices()) {
+  for (auto index : block_info_.block_output_indices()) {
     outputs_.emplace_back(&values_[index]);
   }
+
+  for (auto& pnode : nodes_) {
+    auto* node = pnode.node();
+    auto blocks = node->blocks();
+    const auto num_blocks = blocks.size();
+    if (num_blocks == 0) {
+      continue;
+    }
+    DCHECK(node->kind() == prim::If || node->kind() == prim::Loop);
+    auto block_runners = std::make_unique<std::vector<BlockRunner>>();
+    block_runners->reserve(num_blocks);
+    for (auto* b : blocks) {
+      block_runners->emplace_back(sm, values, b);
+    }
+    pnode.set_block_runners(std::move(block_runners));
+  }
 }
 
-StaticRuntime::~StaticRuntime() = default;
+// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
+BlockRunner::BlockRunner(BlockRunner&&) noexcept = default;
+
+BlockRunner::~BlockRunner() = default;
 
-void StaticRuntime::set_arg(const size_t idx, std::vector<IValue>&& args) {
+void BlockRunner::set_arg(const size_t idx, std::vector<IValue>&& args) {
   DCHECK(idx < args.size());
   Input(idx + first_input_is_self_) = std::move(args[idx]);
 }
 
-void StaticRuntime::set_arg(const size_t idx, const std::vector<IValue>& args) {
+void BlockRunner::set_arg(const size_t idx, const std::vector<IValue>& args) {
   DCHECK(idx < args.size());
   Input(idx + first_input_is_self_) = args[idx];
 }
 
-void StaticRuntime::set_arg(const size_t idx, const IValue& arg) {
+void BlockRunner::set_arg(const size_t idx, const IValue& arg) {
   Input(idx + first_input_is_self_) = arg;
 }
@@ -827,24 +852,21 @@ void check_type(const Argument& schema_arg, const IValue& arg) {
 } // namespace
 
 template <typename IValueList>
-void StaticRuntime::set_inputs(
+void BlockRunner::set_inputs(
     IValueList&& args,
     const std::unordered_map<std::string, c10::IValue>& kwargs) {
   const auto total_num_inputs =
       args.size() + kwargs.size() + first_input_is_self_;
-  TORCH_CHECK(total_num_inputs == static_module_.num_inputs());
+  TORCH_CHECK(total_num_inputs == block_info_.num_inputs());
 
   const auto& schema = static_module_.schema();
   if (first_input_is_self_) {
     Input(0) = static_module_.module()._ivalue();
   }
 
-  if (C10_UNLIKELY(!schema)) {
+  if (!is_root_block_ || C10_UNLIKELY(!schema)) {
     TORCH_CHECK(
-        kwargs.empty(),
-        "Schema is not available, but StaticRuntime got kwargs. "
-        "Consider creating the Static Runtime instance "
-        "with StaticModule(const torch::jit::Module& m) instead.");
+        kwargs.empty(), "Schema is not available, but BlockRunner got kwargs.");
     for (size_t i = 0; i < args.size(); ++i) {
       set_arg(i, std::forward<IValueList>(args));
     }
@@ -887,15 +909,11 @@ void StaticRuntime::set_inputs(
       args.size() + consumed_kwargs == schema_args.size() - 1);
 }
 
-void StaticRuntime::create_memory_planner() {
+void BlockRunner::create_memory_planner() {
   if (!planner_) {
     planner_ = std::make_unique<MemoryPlanner>(
         this,
-        static_module_.value_group(),
-        static_module_.managed_tensor_values(),
-        static_module_.managed_output_tensor_values(),
-        static_module_.leaked_values(),
-        static_module_.managed_tensor_ranges(),
+        block_info_,
         static_module_.opts().enable_out_variant,
         manage_output_tensors_enabled_,
         static_module_.opts().optimize_memory);
@@ -924,7 +942,7 @@ void destroyNodeOutputs(ProcessedNode& p_node) {
 } // namespace
 
-void StaticRuntime::clean_up_intermediate_ivalues() noexcept {
+void BlockRunner::clean_up_intermediate_ivalues() noexcept {
   // We have to iterate in reverse order here due to borrowed
   // IValues - we don't want to destroy a value until all of its
   // borrows are cleaned up!
@@ -933,7 +951,7 @@ void StaticRuntime::clean_up_intermediate_ivalues() noexcept {
   }
 }
 
-void StaticRuntime::resetMemory() noexcept {
+void BlockRunner::resetMemory() noexcept {
   planner_.reset();
   // We must clean up intermediate values before inputs in case
   // there are borrowed inputs and static runtime owns the only
@@ -942,7 +960,7 @@ void StaticRuntime::resetMemory() noexcept {
   clean_up_input_ivalues();
 }
 
-c10::IValue StaticRuntime::move_outputs_to_tuple(uint32_t num_outputs) {
+c10::IValue BlockRunner::move_outputs_to_tuple(uint32_t num_outputs) {
   switch (num_outputs) {
     case 1:
       return c10::ivalue::Tuple::create(IValue(std::move(*outputs_[0])));
@@ -1032,7 +1050,7 @@ c10::IValue StaticRuntime::move_outputs_to_tuple(uint32_t num_outputs) {
 /// buffer) fails. There is still a corner case that fails with the added flag.
 /// If a resize is triggered at the same time as the op creating an alias at the
 /// same time, the current checks would fail to detect the alias.
-void StaticRuntime::verify_and_correct_memory_overlap(ProcessedNode& n) {
+void BlockRunner::verify_and_correct_memory_overlap(ProcessedNode& n) {
   // The slow check can be removed once the internal/output buffers are merged
   if (C10_UNLIKELY(n.check_outputs_for_memory_overlap())) {
     if (C10_UNLIKELY(!planner_)) {
@@ -1065,7 +1083,7 @@ void StaticRuntime::verify_and_correct_memory_overlap(ProcessedNode& n) {
   }
 }
 
-bool StaticRuntime::fast_check_and_correct_overlap_with(
+bool BlockRunner::fast_check_and_correct_overlap_with(
     ProcessedNode& n,
     c10::IValue& tensor_ival) {
   auto& tensor = tensor_ival.toTensor();
@@ -1078,38 +1096,38 @@ bool StaticRuntime::fast_check_and_correct_overlap_with(
   return false;
 }
 
-StaticRuntime::Deallocator::~Deallocator() {
+BlockRunner::Deallocator::~Deallocator() {
   // Assume cleanup cannot throw.
   cleanupImpl();
 #ifndef NDEBUG
-  runtime_.check_for_memory_leak(/*output_returned*/ false);
+  block_runner_.check_for_memory_leak(/*output_returned*/ false);
 #endif
 }
 
-void StaticRuntime::Deallocator::cleanupImpl() {
+void BlockRunner::Deallocator::cleanupImpl() {
   // MemoryPlanner is created after the first invocation of `run()`. This
   // is done intentionally because MemoryPlanner uses `Tensor` sizes of
   // the previous `run()` for memory planning of subsequent runs
   if (C10_LIKELY(finished_)) {
-    runtime_.create_memory_planner();
+    block_runner_.create_memory_planner();
   }
 
-  if (C10_LIKELY(runtime_.planner_)) {
-    runtime_.planner_->deallocate();
+  if (C10_LIKELY(block_runner_.planner_)) {
+    block_runner_.planner_->deallocate();
   } else {
     // This is the first run, and it didn't finish, so we can't use a
     // `MemoryPlanner` to deallocate stuff. Just reset everything mannually.
-    runtime_.resetMemory();
+    block_runner_.resetMemory();
  }
   // clean up owning refs of input tensors
-  runtime_.clean_up_input_ivalues();
+  block_runner_.clean_up_input_ivalues();
   if (C10_UNLIKELY(!finished_)) {
-    runtime_.deallocateOutputTensors();
+    block_runner_.deallocateOutputTensors();
   }
 }
 
 template <typename IValueList>
-c10::IValue StaticRuntime::run_impl(
+c10::IValue BlockRunner::run_impl(
     IValueList&& args,
     const KeywordArgs& kwargs) {
   // We assume inference workloads, so we do not need
@@ -1138,8 +1156,8 @@ c10::IValue StaticRuntime::run_impl(
   }
 
   // no need to keep references of outputs in static runtime anymore
-  if (static_module_.num_outputs() > 1) {
-    return move_outputs_to_tuple(static_module_.num_outputs());
+  if (block_info_.num_outputs() > 1) {
+    return move_outputs_to_tuple(block_info_.num_outputs());
   }
 
   DCHECK(check_for_memory_leak(/*output_returned*/ false));
@@ -1149,7 +1167,7 @@ c10::IValue StaticRuntime::run_impl(
 }
 
 template <typename IValueList>
-c10::IValue StaticRuntime::run_impl_record_functions(
+c10::IValue BlockRunner::run_impl_record_functions(
     IValueList&& args,
     const KeywordArgs& kwargs) {
   bool pre_sampled = false;
@@ -1168,7 +1186,7 @@ c10::IValue StaticRuntime::run_impl_record_functions(
   return run_impl(std::forward<IValueList>(args), kwargs);
 }
 
-c10::IValue StaticRuntime::operator()(
+c10::IValue BlockRunner::operator()(
     const std::vector<c10::IValue>& args,
     const KeywordArgs& kwargs) {
 #ifdef PYTORCH_DISABLE_NET_PROFILING
@@ -1178,7 +1196,7 @@ c10::IValue StaticRuntime::operator()(
 #endif
 }
 
-c10::IValue StaticRuntime::operator()(
+c10::IValue BlockRunner::operator()(
     std::vector<c10::IValue>&& args,
     const KeywordArgs& kwargs) {
 #ifdef PYTORCH_DISABLE_NET_PROFILING
@@ -1205,7 +1223,7 @@ std::string generate_latency_json(const std::string& label, double millis) {
 } // namespace
 
-void StaticRuntime::benchmark(
+void BlockRunner::benchmark(
     const std::vector<std::vector<c10::IValue>>& args_list,
     const std::vector<KeywordArgs>& kwargs_list,
     const int warmup_runs,
@@ -1267,7 +1285,7 @@ void StaticRuntime::benchmark(
   }
   std::cout << std::setw(15) << results.total_time << " ms. in Total"
             << std::endl;
-  std::cout << "StaticRuntime setup time: " << results.setup_time << " ms"
+  std::cout << "BlockRunner setup time: " << results.setup_time << " ms"
             << std::endl;
   std::cout << "Memory allocation time: " << results.memory_alloc_time
             << " ms\n";
@@ -1312,7 +1330,7 @@ void StaticRuntime::benchmark(
 #endif
 }
 
-float StaticRuntime::benchmark_model(
+float BlockRunner::benchmark_model(
     const std::vector<std::vector<c10::IValue>>& args_list,
     const std::vector<KeywordArgs>& kwargs_list,
     const int warmup_runs,
@@ -1396,7 +1414,7 @@ void display_pnode_info(const ProcessedNode& pnode) {
   }
 }
 
-void StaticRuntime::display_nodes(
+void BlockRunner::display_nodes(
     const std::vector<c10::IValue>& args,
     const KeywordArgs& kwargs) {
   c10::InferenceMode mode;
@@ -1415,7 +1433,7 @@ void StaticRuntime::display_nodes(
   on_exit.setFinished();
 }
 
-StaticRuntime::IndividualMetrics StaticRuntime::benchmark_individual_ops(
+BlockRunner::IndividualMetrics BlockRunner::benchmark_individual_ops(
     const std::vector<std::vector<c10::IValue>>& args_list,
     const std::vector<KeywordArgs>& kwargs_list,
     const int warmup_runs,
@@ -1543,10 +1561,16 @@ StaticRuntime::IndividualMetrics StaticRuntime::benchmark_individual_ops(
   return results;
 }
 
-bool StaticRuntime::check_for_memory_leak(bool output_returned) {
+bool BlockRunner::check_for_memory_leak(
+    bool output_returned,
+    bool recurse_on_sub_blocks) {
   // check for inputs
-  for (const auto i : c10::irange(static_module_.num_inputs())) {
-    TORCH_CHECK(values_[i].isNone(), "Input ", i, " was not cleaned up");
+  for (const auto i : c10::irange(block_info_.num_inputs())) {
+    TORCH_CHECK(
+        values_[i + block_info_.block_inputs_idx()].isNone(),
+        "Input ",
+        i,
+        " was not cleaned up");
   }
   FastSet<const IValue*> output_ivalues(outputs_.begin(), outputs_.end());
   for (const auto n : c10::irange(nodes_.size())) {
@@ -1561,7 +1585,8 @@ bool StaticRuntime::check_for_memory_leak(bool output_returned) {
           (isManagedOutputTensor(*ival) || isManagedOutputTensorValue(val))) {
         // `ival` contains a managed output tensor that the runtime doesn't
         // reclaim at the end of an iteration, but the client does so
-        // by explicitly calling `StaticRuntime::deallocateOutputTensors`.
+        // by explicitly calling
+        // `BlockRunner::deallocateOutputTensors`.
         continue;
       }
       const std::string error_msg = "Output " + c10::to_string(i) + ", %" +
@@ -1573,7 +1598,8 @@
       if (!ival->isNone()) {
         TORCH_CHECK(
             ival->isTensor() ||
-                static_module_.is_optimizable_container_type(pnode.node()) ||
+                block_info_.node_is_optimizable_container_type(
+                    pnode.node()) ||
                 doesNotHeapAllocateWhenStoredInIValue(*val->type()),
             error_msg);
         if (ival->isTensor()) {
@@ -1595,12 +1621,20 @@ bool StaticRuntime::check_for_memory_leak(bool output_returned) {
         }
       }
     }
+    auto* block_runners = pnode.block_runners();
+    if (recurse_on_sub_blocks && block_runners) {
+      for (auto& block_runner : *block_runners) {
+        block_runner.check_for_memory_leak(
+            output_returned, recurse_on_sub_blocks);
+      }
+    }
   }
   VLOG(1) << "Finished checking for memory leak";
   return true;
 }
 
-void StaticRuntime::deallocateOutputTensors() {
+void BlockRunner::deallocateOutputTensors() {
   if (!static_module_.opts().manage_output_tensors) {
     TORCH_CHECK(
         !planner_ || planner_->numOutputBufferBytes() == 0,
@@ -1613,7 +1647,7 @@ void StaticRuntime::deallocateOutputTensors() {
   }
 }
 
-bool StaticRuntime::checkOutputTensorMemoryLeaks() {
+bool BlockRunner::checkOutputTensorMemoryLeaks() {
   if (!static_module_.opts().manage_output_tensors || !planner_) {
     return true;
   }
@@ -1639,21 +1673,21 @@ bool StaticRuntime::checkOutputTensorMemoryLeaks() {
   return true;
 }
 
-bool StaticRuntime::isManagedOutputTensor(const IValue& ivalue) const {
+bool BlockRunner::isManagedOutputTensor(const IValue& ivalue) const {
   return planner_ && planner_->isManagedOutputTensor(ivalue);
 }
 
-bool StaticRuntime::isManagedOutputTensorValue(const Value* value) const {
+bool BlockRunner::isManagedOutputTensorValue(const Value* value) const {
   // It's possible that manage_output_tensors_ was disabled after initializing
   // managed_output_tensor_values, so we have to check that flag here.
   if (!planner_ || !manage_output_tensors_enabled_) {
     return false;
   }
-  const auto& managed_outputs = static_module_.managed_output_tensor_values();
+  const auto& managed_outputs = block_info_.managed_output_tensor_values();
   return managed_outputs.find(value) != managed_outputs.end();
 }
 
-void StaticRuntime::disableManageOutputTensors() {
+void BlockRunner::disableManageOutputTensors() {
   if (!manage_output_tensors_enabled_) {
     return;
   }
@@ -1915,5 +1949,50 @@ void ProcessedNode::verify_and_correct_memory_overlap() {
   }
 }
 
+StaticRuntime::StaticRuntime(const StaticModule& sm) {
+  values_.resize(sm.value_buffer_size());
+  std::copy(sm.constants().begin(), sm.constants().end(), values_.begin());
+  block_ = std::make_unique<BlockRunner>(
+      sm, values_, sm.root_block(), /*is_root_block*/ true);
+  ;
+}
+
+c10::IValue StaticRuntime::operator()(
+    const std::vector<c10::IValue>& args,
+    const KeywordArgs& kwargs) {
+  return (*block_)(args, kwargs);
+}
+
+c10::IValue StaticRuntime::operator()(
+    std::vector<c10::IValue>&& args,
+    const KeywordArgs& kwargs) {
+  return (*block_)(std::move(args), kwargs);
+}
+
+bool StaticRuntime::check_for_memory_leak(bool output_returned) {
+  return block_->check_for_memory_leak(
+      output_returned, /* recurse_on_sub_blocks */ true);
+}
+
+bool StaticRuntime::checkOutputTensorMemoryLeaks() {
+  return block_->checkOutputTensorMemoryLeaks();
+}
+
+void StaticRuntime::deallocateOutputTensors() {
+  block_->deallocateOutputTensors();
+}
+
+bool StaticRuntime::isManagedOutputTensor(const IValue& ivalue) const {
+  return block_->isManagedOutputTensor(ivalue);
+}
+
+void StaticRuntime::disableManageOutputTensors() {
+  block_->disableManageOutputTensors();
+}
+
+const MemoryPlanner* StaticRuntime::get_memory_planner() const {
+  return block_->get_memory_planner();
+}
+
 } // namespace jit
 } // namespace torch


@@ -13,6 +13,7 @@
 #include <torch/csrc/jit/passes/freeze_module.h>
 #include <torch/csrc/jit/passes/inliner.h>
 #include <torch/csrc/jit/runtime/static/ProcessedNodeInputs.h>
+#include <limits>
 
 #ifdef FBCODE_CAFFE2
 #include <folly/container/F14Map.h>
@@ -82,7 +83,7 @@ TORCH_API inline bool borrowsOutputs(c10::Symbol kind) {
 class ValueGroup {
  public:
   explicit ValueGroup() = default;
-  void init(const std::shared_ptr<torch::jit::Graph>& graph, AliasDb& db);
+  void init(const Block& block, const AliasDb& db);
 
   bool isExternalAlias(const Value* value) const {
     return external_aliases_.find(value) != external_aliases_.end();
@@ -112,7 +113,8 @@ class TORCH_API ManagedTensorRanges {
  public:
   ManagedTensorRanges() = default;
   ManagedTensorRanges(
-      const std::shared_ptr<Graph>& graph,
+      Block& block,
+      const AliasDb& alias_db,
       const FastSet<const Value*>& managed_tensor_values);
 
   // If true, then this node is the last use of at least one
@@ -213,11 +215,122 @@ struct TORCH_API StaticModuleOptions {
 /// pool.push(runtime);
 /// @endcode
 ///
 class MemoryPlanner;
 class ProcessedFunction;
 class ProcessedNode;
 class StaticRuntime;
 
+// A `BlockInfo` instance stores all of the shared state that each
+// `BlockRunner` will need to access. Most of this information is
+// read-only and shared between threads.
+// - Each `BlockInfo` corresponds to one block in the graph.
+// - Each `BlockInfo` may be used by multiple block runners (when there are many
+//   threads).
+// - All of the `BlockInfo`s are stored in a vector in the `StaticModule` and
+//   are initialized during `StaticModule` construction.
+// - Most of the information stored is used to initialize the block's memory
+//   planner.
+class BlockInfo {
+ public:
+  BlockInfo(uint32_t input_idx, Block& block)
+      : input_idx_(input_idx), block_(block) {}
+
+  void set_nodes(
+      std::vector<ProcessedNode> nodes,
+      const FastMap<Node*, bool>& node_has_out_variant);
+
+  const std::vector<ProcessedNode>& nodes() const {
+    return nodes_;
+  }
+
+  size_t num_nodes() const {
+    return nodes_.size();
+  }
+
+  size_t num_inputs() const {
+    return block_.inputs().size();
+  }
+
+  size_t num_outputs() const {
+    return block_.outputs().size();
+  }
+
+  graph_node_list node_ptrs() const {
+    return block_.nodes();
+  }
+
+  void set_output_indices(std::vector<uint16_t> indices) {
+    output_indices_ = std::move(indices);
+  }
+
+  const std::vector<uint16_t>& block_output_indices() const {
+    return output_indices_;
+  }
+
+  auto block_inputs_idx() const {
+    return input_idx_;
+  }
+
+  bool node_is_optimizable_container_type(const Node* node) const {
+    return node_is_optimizable_container_type_.find(node) !=
+        node_is_optimizable_container_type_.end();
+  }
+
+  bool value_is_managed_tensor(const Value* value) const {
+    return managed_tensor_values_.find(value) != managed_tensor_values_.end();
+  }
+
+  bool value_is_leaked_container(const Value* value) const {
+    return leaked_values_.find(value) != leaked_values_.end();
+  }
+
+  const ValueGroup& value_group() const {
+    return value_group_;
+  }
+
+  const ManagedTensorRanges& managed_tensor_ranges() const {
+    return managed_tensor_ranges_;
+  }
+
+  void init_value_group(const AliasDb& alias_db) {
+    value_group_.init(block_, alias_db);
+  }
+
+  void prepare_for_memory_planner(
+      const AliasDb& alias_db,
+      const StaticModuleOptions& opt);
+
+  const auto& managed_output_tensor_values() const {
+    return managed_output_tensor_values_;
+  }
+
+  const auto& managed_tensor_values() const {
+    return managed_tensor_values_;
+  }
+
+  const auto& leaked_values() const {
+    return leaked_values_;
+  }
+
+ private:
+  std::vector<ProcessedNode> nodes_;
+
+  ValueGroup value_group_;
+
+  FastSet<const Node*> node_is_optimizable_container_type_;
+  FastSet<const Value*> managed_tensor_values_;
+  FastSet<const Value*> managed_output_tensor_values_;
+  FastSet<const Value*> leaked_values_;
+
+  ManagedTensorRanges managed_tensor_ranges_{};
+
+  // The index of this block's inputs in the shared values_ array.
+  const uint16_t input_idx_;
+  // The indices of this block's outputs in the shared values_ array.
+  std::vector<uint16_t> output_indices_;
+  Block& block_;
+};
+
 class TORCH_API StaticModule {
  public:
   explicit StaticModule(
@@ -231,23 +344,12 @@ class TORCH_API StaticModule {
       const StaticModuleOptions& opts = StaticModuleOptions(),
       std::vector<IValue> sample_inputs = {});
 
-  typedef enum {
-    CONSTANT_VALUE = -2, // VALUE nodes defined by prim::Constant
-    INPUT_VALUE = -1, // VALUE nodes representing graph inputs
-  } VALUE_KIND;
-
  private:
   explicit StaticModule(
       std::pair<std::shared_ptr<torch::jit::Graph>, c10::optional<Module>>
           graph_and_module,
       const StaticModuleOptions& opts);
 
-  // for <kind, idx>
-  //   if kind == CONSTANT_VALUE: map to constants_[idx]
-  //   if kind == INPUT_VALUE: map to inputs_[idx]
-  //   otherwise: map to nodes_[kind].outputs()[idx]
-  using DefInfo = std::pair<int, int>;
-
  public:
   using KeywordArgs = std::unordered_map<std::string, c10::IValue>;
   c10::IValue operator()(
@@ -268,10 +370,6 @@ class TORCH_API StaticModule {
 
   const StaticModuleOptions& opts() const;
 
-  const ValueGroup& valueGroup() const {
-    return value_group_;
-  }
-
   size_t num_inputs() const;
   size_t num_outputs() const;
@ -295,74 +393,69 @@ class TORCH_API StaticModule {
return constants_; return constants_;
} }
const BlockInfo& block_info(Block* block) const {
return block_infos_.at(block);
}
Block* root_block() const {
return graph_->block();
}
private: private:
friend class StaticRuntime; friend class StaticRuntime;
friend class BlockRunner;
// Our nodes don't have their inputs & outputs initialized; don't
// let anybody but StaticRuntime and tests get them.
const std::vector<ProcessedNode>& nodes() const {
return nodes_;
}
public: public:
auto num_nodes() const { auto num_nodes() const {
return nodes_.size(); return std::accumulate(
block_infos_.begin(),
block_infos_.end(),
0,
[](size_t sum, const auto& block_and_info) {
auto& block_info = block_and_info.second;
return sum + block_info.num_nodes();
});
} }
C10_NODISCARD Node* findNodeWithKindForTesting(const std::string& kind) const; C10_NODISCARD Node* findNodeWithKindForTesting(const std::string& kind) const;
graph_node_list node_ptrs() const {
return graph_->nodes();
}
bool is_optimizable_container_type(const Node* n) const {
auto it = node_is_optimizable_container_type_.find(n);
return it != node_is_optimizable_container_type_.end();
}
const c10::optional<c10::FunctionSchema>& schema() const { const c10::optional<c10::FunctionSchema>& schema() const {
return schema_; return schema_;
} }
const ValueGroup& value_group() const {
return value_group_;
}
const FastSet<const Value*>& managed_tensor_values() const {
return managed_tensor_values_;
}
const FastSet<const Value*>& managed_output_tensor_values() const {
return managed_output_tensor_values_;
}
const FastSet<const Value*>& leaked_values() const {
return leaked_values_;
}
const ManagedTensorRanges& managed_tensor_ranges() const {
return managed_tensor_ranges_;
}
bool first_input_is_self() const { bool first_input_is_self() const {
return module_.has_value(); return module_.has_value();
} }
size_t inputs_offset() const {
return 0;
}
size_t constants_offset() const {
return inputs_offset() + num_inputs();
}
size_t intermediate_values_offset() const {
return constants_offset() + num_constants();
}
StaticRuntime& runtime(); StaticRuntime& runtime();
// See [Shared values array]
size_t value_buffer_size() const {
return value_buffer_size_;
}
private: private:
// Recursively prepares the BlockInfo array.
// - Populates `value_to_index` with the indices of each intermediate value
// - Returns the number of Value* processed, including sub-blocks.
size_t prepareBlockInfo(
Block* block,
const size_t start_idx,
FastMap<const Value*, uint32_t>& value_to_index);
void prepareFunctionsAndConstants(
Block* block,
const AliasDb& alias_db,
FastMap<const Value*, uint32_t>& value_to_index);
// Recurses on sub-blocks and populates the array of ProcessedNodes
// Returns the number of nodes processed (including nodes in sub-blocks)
size_t prepareProcessedNodes(
Block* block,
const FastMap<const Value*, uint32_t>& value_to_index,
const AliasDb& alias_db,
size_t node_idx = 0);
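  // Illustrative only, not a real member: a minimal sketch of the depth-first
  // index assignment that the helpers above perform. The real prepareBlockInfo
  // also records each block's input/output indices in its BlockInfo, and
  // constants are handled separately (they are pooled at the front of the
  // shared array).
  static size_t assign_value_indices_sketch(
      Block* block,
      size_t start_idx,
      FastMap<const Value*, uint32_t>& value_to_index) {
    size_t count = 0;
    // [inputs_i]: the block's own inputs occupy a contiguous range first.
    for (Value* input : block->inputs()) {
      value_to_index.emplace(input, static_cast<uint32_t>(start_idx + count++));
    }
    // [intermediates_i]: then every value produced by the block's nodes.
    for (Node* node : block->nodes()) {
      for (Value* output : node->outputs()) {
        value_to_index.emplace(output, static_cast<uint32_t>(start_idx + count++));
      }
    }
    // Sub-blocks get their own contiguous chunks, laid out depth-first.
    for (Node* node : block->nodes()) {
      for (Block* sub_block : node->blocks()) {
        count += assign_value_indices_sketch(
            sub_block, start_idx + count, value_to_index);
      }
    }
    return count;
  }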
// Initialize various attributes that the memory planner will need. // Initialize various attributes that the memory planner will need.
// To be called at the tail of the ctor. // To be called at the tail of the ctor.
void prepareForMemoryPlanner(); void prepareForMemoryPlanner();
@ -383,15 +476,6 @@ class TORCH_API StaticModule {
// Indices of graph outputs in the single values array. // Indices of graph outputs in the single values array.
std::vector<uint16_t> output_indices_; std::vector<uint16_t> output_indices_;
ValueGroup value_group_;
FastSet<const Node*> node_is_optimizable_container_type_;
FastSet<const Value*> managed_tensor_values_{};
FastSet<const Value*> managed_output_tensor_values_{};
FastSet<const Value*> leaked_values_{};
ManagedTensorRanges managed_tensor_ranges_{};
size_t num_intermediate_values_ = 0; size_t num_intermediate_values_ = 0;
// Includes self if module_ != nullopt. // Includes self if module_ != nullopt.
@ -399,16 +483,33 @@ class TORCH_API StaticModule {
// argument. In this case, `self` isn't used in the graph, but the schema // argument. In this case, `self` isn't used in the graph, but the schema
// includes it anyways to be consistent with the JIT interpreter. // includes it anyways to be consistent with the JIT interpreter.
size_t num_inputs_; size_t num_inputs_;
// See `BlockInfo` definition. The blocks are stored in depth-first order.
FastMap<Block*, BlockInfo> block_infos_;
size_t value_buffer_size_ = 0;
}; };
class TORCH_API StaticRuntime { // `BlockRunner` contains the core runtime logic. Each block runner
// corresponds to one block in the graph and has its own memory planner.
// `StaticRuntime` will initialize all `BlockRunner`s
// upon construction. Each block runner only directly executes nodes from its
// block. Special ops with sub-blocks like `prim::If` may have
// `BlockRunner`s stored in their `ProcessedNode`s; these
// sub-blocks get executed in the op's implementation.
// `StaticRuntime` stores a vector of IValues that all
// `BlockRunner`s share. This vector is used to store all
// constants, inputs, and intermediate tensors.
class TORCH_API BlockRunner {
public: public:
explicit StaticRuntime(const StaticModule& sm); BlockRunner(
StaticRuntime(StaticRuntime&&) = delete; const StaticModule& sm,
StaticRuntime& operator=(StaticRuntime&&) = delete; std::vector<IValue>& values,
~StaticRuntime(); Block* block,
bool is_root_block = false);
BlockRunner(BlockRunner&&) noexcept;
BlockRunner& operator=(BlockRunner&&) = delete;
~BlockRunner();
C10_DISABLE_COPY_AND_ASSIGN(StaticRuntime); C10_DISABLE_COPY_AND_ASSIGN(BlockRunner);
using KeywordArgs = std::unordered_map<std::string, c10::IValue>; using KeywordArgs = std::unordered_map<std::string, c10::IValue>;
c10::IValue operator()( c10::IValue operator()(
@ -451,11 +552,16 @@ class TORCH_API StaticRuntime {
// Input is readwrite // Input is readwrite
IValue& Input(uint32_t i) { IValue& Input(uint32_t i) {
DCHECK_LT(i, static_module_.num_inputs()); DCHECK_LT(i, block_info_.num_inputs());
DCHECK_LT(i, values_.size()); DCHECK_LT(i, values_.size());
return values_[i]; return values_[i + block_info_.block_inputs_idx()];
} }
size_t init_sub_blocks(
const StaticModule& sm,
std::vector<IValue>& values,
size_t block_idx);
// Output is readonly. The writing process happens inside ProcessedNodes // Output is readonly. The writing process happens inside ProcessedNodes
C10_NODISCARD const IValue& Output(uint32_t i) const { C10_NODISCARD const IValue& Output(uint32_t i) const {
DCHECK(i < outputs_.size()); DCHECK(i < outputs_.size());
@ -475,7 +581,7 @@ class TORCH_API StaticRuntime {
} }
graph_node_list node_ptrs() const { graph_node_list node_ptrs() const {
return static_module_.node_ptrs(); return block_info_.node_ptrs();
} }
const Graph& graph() const { const Graph& graph() const {
@ -486,11 +592,9 @@ class TORCH_API StaticRuntime {
return planner_.get(); return planner_.get();
} }
bool check_for_memory_leak(bool output_returned = true); bool check_for_memory_leak(
bool output_returned = true,
bool is_optimizable_container_type(Node* n) const { bool recurse_on_sub_blocks = false);
return static_module_.is_optimizable_container_type(n);
}
// WARNING: Deallocate managed output tensors. A client receiving Static // WARNING: Deallocate managed output tensors. A client receiving Static
// Runtime-managed Tensors needs to be very careful to call // Runtime-managed Tensors needs to be very careful to call
@ -521,7 +625,8 @@ class TORCH_API StaticRuntime {
// when destructed. // when destructed.
class Deallocator { class Deallocator {
public: public:
explicit Deallocator(StaticRuntime& runtime) : runtime_(runtime) {} explicit Deallocator(BlockRunner& block_runner)
: block_runner_(block_runner) {}
Deallocator(Deallocator&&) = default; Deallocator(Deallocator&&) = default;
Deallocator(const Deallocator&) = default; Deallocator(const Deallocator&) = default;
@ -537,7 +642,7 @@ class TORCH_API StaticRuntime {
void cleanupImpl(); void cleanupImpl();
bool finished_ = false; bool finished_ = false;
StaticRuntime& runtime_; BlockRunner& block_runner_;
}; };
template <typename IValueList> template <typename IValueList>
@ -569,8 +674,8 @@ class TORCH_API StaticRuntime {
// clean up owning refs of input IValues // clean up owning refs of input IValues
void clean_up_input_ivalues() noexcept { void clean_up_input_ivalues() noexcept {
for (const auto idx : c10::irange(static_module_.num_inputs())) { for (const auto idx : c10::irange(block_info_.num_inputs())) {
values_[idx] = IValue(); values_[idx + inputs_begin_] = IValue();
} }
} }
@ -591,16 +696,29 @@ class TORCH_API StaticRuntime {
const KeywordArgs& kwargs); const KeywordArgs& kwargs);
const StaticModule& static_module_; const StaticModule& static_module_;
const BlockInfo& block_info_;
const bool is_root_block_;
// Cache this so we don't have to call static_module_.first_input_is_self() // Cache this so we don't have to call static_module_.first_input_is_self()
const bool first_input_is_self_; const bool first_input_is_self_;
// Index of the start of this block's inputs in the shared values_ array.
const uint16_t inputs_begin_;
bool manage_output_tensors_enabled_ = false; bool manage_output_tensors_enabled_ = false;
std::unique_ptr<MemoryPlanner> planner_; std::unique_ptr<MemoryPlanner> planner_;
// first static_module_.num_inputs() slots are inputs, next // [Shared values array]
// static_module_.constants().size() slots are a copy of // ProcessedNodes reference their inputs and outputs with
// static_module_.constants(), rest are regular values in the
// graph. ProcessedNodes reference their inputs and outputs with
// offsets into this array, which saves memory. // offsets into this array, which saves memory.
std::vector<IValue> values_; // All BlockRunners share the same array. The layout is as
// follows:
// [constants][block_0][block_1]...[block_N]
// Note that constants from all blocks are pooled together at the start.
// The block ordering is depth-first.
// Each block is further divided into inputs and intermediates:
// [block_i] = [inputs_i][intermediates_i]
// Each BlockRunner knows where its inputs start. Each ProcessedNode
// knows how to find the indices of its outputs/inputs in this array.
std::vector<IValue>& values_;
std::vector<IValue*> outputs_; std::vector<IValue*> outputs_;
std::vector<ProcessedNode> nodes_; std::vector<ProcessedNode> nodes_;
}; };
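// A made-up worked example of the layout described above (all counts are
// illustrative): with 2 constants, a root block that has 3 inputs and 4
// intermediate values, and one sub-block with 1 input and 2 intermediates,
// the shared array has 12 slots:
//
//   slots 0-1  : constants c0, c1
//   slots 2-4  : root block inputs
//   slots 5-8  : root block intermediates
//   slot  9    : sub-block input
//   slots 10-11: sub-block intermediates
//
// The root BlockRunner's inputs_begin_ would be 2, the sub-block runner's
// would be 9, and StaticModule::value_buffer_size() would report 12.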
@ -643,15 +761,55 @@ class TORCH_API ProcessedNode {
ProcessedNode() = default; ProcessedNode() = default;
// ProcessedNodes are created within StaticModule and then // ProcessedNodes are created within StaticModule and then
// associated with a shared values array using set_values() when // associated with a shared values array using set_values() when
// they are copied into a StaticRuntime. // they are copied into a StaticRuntime. block_runners_ is also
// not initialized until StaticRuntime initialization; see
// BlockRunner's ctor.
ProcessedNode( ProcessedNode(
Node* n, Node* n,
ProcessedFunction* fn, ProcessedFunction* fn,
ProcessedNodeInputs inputs, ProcessedNodeInputs inputs,
uint16_t outputs_offset); uint16_t outputs_offset);
ProcessedNode(const ProcessedNode&) = default; ProcessedNode(const ProcessedNode& other)
ProcessedNode& operator=(const ProcessedNode&) = default; : node_(other.node_),
fn_(other.fn_),
overlap_detected_(other.overlap_detected_),
inputs_(other.inputs_),
outputs_offset_(other.outputs_offset_),
num_outputs_(other.num_outputs_),
values_(other.values_),
// It doesn't really make sense to copy block runners;
// each ProcessedNode needs its own. This is OK to do
// since ProcessedNodes are copied from StaticModule right before
// the block runners are set up.
// TODO(T105178680): For this task, we should move
// block runners out of ProcessedNode. Then, we don't have to deal
// with this caveat.
block_runners_(nullptr)
#ifndef PYTORCH_DISABLE_PER_OP_PROFILING
,
op_name_(other.op_name_)
#endif
{
}
ProcessedNode& operator=(const ProcessedNode& other) {
if (&other == this) {
return *this;
}
node_ = other.node_;
fn_ = other.fn_;
overlap_detected_ = other.overlap_detected_;
inputs_ = other.inputs_;
outputs_offset_ = other.outputs_offset_;
num_outputs_ = other.num_outputs_;
values_ = other.values_;
block_runners_ = nullptr;
#ifndef PYTORCH_DISABLE_PER_OP_PROFILING
op_name_ = other.op_name_;
#endif
return *this;
}
// These should be noexcept, but some Android build is failing // These should be noexcept, but some Android build is failing
// saying the noexcept specification doesn't match the calculated // saying the noexcept specification doesn't match the calculated
@ -732,11 +890,21 @@ class TORCH_API ProcessedNode {
} }
C10_NODISCARD uint16_t output_ivalue_index(uint16_t i) const { C10_NODISCARD uint16_t output_ivalue_index(uint16_t i) const {
DCHECK(i < num_outputs_);
return outputs_offset_ + i; return outputs_offset_ + i;
} }
// used in debug mode // used in debug mode
bool verify_no_memory_overlap(bool force_check = false) const; bool verify_no_memory_overlap(bool force_check = false) const;
std::vector<BlockRunner>* block_runners() {
return block_runners_.get();
}
void set_block_runners(
std::unique_ptr<std::vector<BlockRunner>> block_runners) {
block_runners_ = std::move(block_runners);
}
private: private:
C10_NODISCARD bool verify_outputs_dont_overlap_each_other() const; C10_NODISCARD bool verify_outputs_dont_overlap_each_other() const;
@ -749,10 +917,74 @@ class TORCH_API ProcessedNode {
uint16_t outputs_offset_; uint16_t outputs_offset_;
uint16_t num_outputs_; uint16_t num_outputs_;
IValue* values_ = nullptr; // unowned IValue* values_ = nullptr; // unowned
// For control flow: processed nodes may have sub-blocks that can
// be executed by op implementations.
std::unique_ptr<std::vector<BlockRunner>> block_runners_;
#ifndef PYTORCH_DISABLE_PER_OP_PROFILING #ifndef PYTORCH_DISABLE_PER_OP_PROFILING
const char* op_name_; const char* op_name_;
#endif #endif
}; };
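// A hedged sketch (the helper name is made up; the real wiring lives in
// BlockRunner::init_sub_blocks): how sub-block runners could be attached to
// ProcessedNodes once the shared values array exists. Ops with sub-blocks,
// e.g. prim::If, would then retrieve them via block_runners() inside their
// implementations.
inline void attach_sub_block_runners_sketch(
    const StaticModule& sm,
    std::vector<IValue>& values,
    std::vector<ProcessedNode>& nodes) {
  for (ProcessedNode& pnode : nodes) {
    Node* node = pnode.node();
    if (node->blocks().empty()) {
      continue;
    }
    auto runners = std::make_unique<std::vector<BlockRunner>>();
    for (Block* sub_block : node->blocks()) {
      // Every sub-block runner shares the same values array; only the
      // top-level runner is constructed with is_root_block = true.
      runners->emplace_back(sm, values, sub_block, /*is_root_block=*/false);
    }
    pnode.set_block_runners(std::move(runners));
  }
}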
// `StaticRuntime` is the owner of the array of IValues (used for constants,
// inputs, and intermediate tensors) that all `BlockRunner`s share.
// Upon construction, it initializes all block runners. `operator()` simply
// forwards the inputs to the top-level block runner. Each `StaticRuntime`
// instance corresponds to one `StaticModule`. Multiple `StaticRuntime`
// instances can be created; this is useful for multi-threaded execution, since
// `operator()` is not thread-safe.
class TORCH_API StaticRuntime {
public:
explicit StaticRuntime(const StaticModule& sm);
using KeywordArgs = std::unordered_map<std::string, c10::IValue>;
c10::IValue operator()(
const std::vector<c10::IValue>& args,
const KeywordArgs& kwargs = KeywordArgs());
c10::IValue operator()(
std::vector<c10::IValue>&& args,
const KeywordArgs& kwargs = KeywordArgs());
bool check_for_memory_leak(bool output_returned = true);
bool checkOutputTensorMemoryLeaks();
void deallocateOutputTensors();
bool isManagedOutputTensor(const IValue& ivalue) const;
void disableManageOutputTensors();
// Gets the top-level memory planner. Used for testing.
const MemoryPlanner* get_memory_planner() const;
void benchmark(
const std::vector<std::vector<c10::IValue>>& args_list,
const std::vector<KeywordArgs>& kwargs_list,
const int warmup_runs,
const int main_runs,
bool print_per_node_time = false,
bool generate_ai_pep_output = false) {
block_->benchmark(
args_list,
kwargs_list,
warmup_runs,
main_runs,
print_per_node_time,
generate_ai_pep_output);
}
using IndividualMetrics = BlockRunner::IndividualMetrics;
IndividualMetrics benchmark_individual_ops(
const std::vector<std::vector<c10::IValue>>& args_list,
const std::vector<KeywordArgs>& kwargs_list,
const int warmup_runs,
const int main_runs) {
return block_->benchmark_individual_ops(
args_list, kwargs_list, warmup_runs, main_runs);
}
private:
std::unique_ptr<BlockRunner> block_;
std::vector<IValue> values_;
};
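// A minimal usage sketch (module construction elided; the input shapes are
// made up). Each thread should own its own StaticRuntime, since operator()
// is not thread-safe.
inline c10::IValue run_once_sketch(const StaticModule& module) {
  StaticRuntime runtime(module);
  std::vector<c10::IValue> args{at::randn({8, 8}), at::randn({8, 8})};
  c10::IValue result = runtime(args);
  runtime.check_for_memory_leak(); // optional sanity check
  return result;
}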
} // namespace jit } // namespace jit
} // namespace torch } // namespace torch
View File
@ -129,10 +129,10 @@ bool setIncludes(const FastSet<const Value*>& set, const Value* v) {
} }
std::vector<std::pair<size_t, at::Tensor*>> assignStorageToOutputTensors( std::vector<std::pair<size_t, at::Tensor*>> assignStorageToOutputTensors(
StaticRuntime* runtime, BlockRunner* block_runner,
const FastSet<const Value*>& managed_output_tensor_values) { const FastSet<const Value*>& managed_output_tensor_values) {
std::vector<std::pair<size_t, at::Tensor*>> managed_output_tensors; std::vector<std::pair<size_t, at::Tensor*>> managed_output_tensors;
for (auto& pnode : runtime->nodes()) { for (auto& pnode : block_runner->nodes()) {
for (const auto i : c10::irange(pnode.outputs().size())) { for (const auto i : c10::irange(pnode.outputs().size())) {
auto& ival = pnode.Output(i); auto& ival = pnode.Output(i);
const auto* val = pnode.node()->outputs()[i]; const auto* val = pnode.node()->outputs()[i];
@ -151,19 +151,20 @@ std::vector<std::pair<size_t, at::Tensor*>> assignStorageToOutputTensors(
} // namespace } // namespace
MemoryPlanner::MemoryPlanner( MemoryPlanner::MemoryPlanner(
StaticRuntime* runtime, BlockRunner* block_runner,
const ValueGroup& value_group, const BlockInfo& block_info,
const FastSet<const Value*>& managed_tensor_values,
const FastSet<const Value*>& managed_output_tensor_values,
const FastSet<const Value*>& leaked_values,
const ManagedTensorRanges& ranges,
bool enable_out_variant, bool enable_out_variant,
bool manage_output_tensors, bool manage_output_tensors,
bool optimize_memory) { bool optimize_memory) {
const auto& managed_tensor_values = block_info.managed_tensor_values();
const auto& managed_output_tensor_values =
block_info.managed_output_tensor_values();
const auto& leaked_values = block_info.leaked_values();
// collect unmanaged output ivalues // collect unmanaged output ivalues
FastSet<IValue*> unmanaged_ivalues; FastSet<IValue*> unmanaged_ivalues;
FastSet<IValue*> unmanaged_borrowed_ivalues; FastSet<IValue*> unmanaged_borrowed_ivalues;
for (ProcessedNode& pnode : runtime->nodes()) { for (ProcessedNode& pnode : block_runner->nodes()) {
const auto borrows_outputs = borrowsOutputs(pnode.node()->kind()); const auto borrows_outputs = borrowsOutputs(pnode.node()->kind());
for (const auto i : c10::irange(pnode.outputs().size())) { for (const auto i : c10::irange(pnode.outputs().size())) {
const Value* out_v = pnode.node()->outputs()[i]; const Value* out_v = pnode.node()->outputs()[i];
@ -189,7 +190,7 @@ MemoryPlanner::MemoryPlanner(
} }
} }
} }
for (IValue* output : runtime->outputs()) { for (IValue* output : block_runner->outputs()) {
auto it = unmanaged_borrowed_ivalues.find(output); auto it = unmanaged_borrowed_ivalues.find(output);
if (it != unmanaged_borrowed_ivalues.end()) { if (it != unmanaged_borrowed_ivalues.end()) {
borrowed_ivalues_needing_incref_.push_back(output); borrowed_ivalues_needing_incref_.push_back(output);
@ -213,10 +214,12 @@ MemoryPlanner::MemoryPlanner(
if (enable_out_variant) { if (enable_out_variant) {
const auto tensor_value_to_tensor = const auto tensor_value_to_tensor =
tensorValueToTensor(runtime->nodes(), managed_tensor_values); tensorValueToTensor(block_runner->nodes(), managed_tensor_values);
if (optimize_memory) { if (optimize_memory) {
managed_tensors_ = assignStorageToManagedTensors( managed_tensors_ = assignStorageToManagedTensors(
runtime->node_ptrs(), ranges, tensor_value_to_tensor); block_info.node_ptrs(),
block_info.managed_tensor_ranges(),
tensor_value_to_tensor);
} else { } else {
for (auto& tensor : tensor_value_to_tensor) { for (auto& tensor : tensor_value_to_tensor) {
managed_tensors_.emplace_back(tensor.second); managed_tensors_.emplace_back(tensor.second);
@ -225,8 +228,8 @@ MemoryPlanner::MemoryPlanner(
} }
if (enable_out_variant && manage_output_tensors) { if (enable_out_variant && manage_output_tensors) {
managed_output_tensors_ = managed_output_tensors_ = assignStorageToOutputTensors(
assignStorageToOutputTensors(runtime, managed_output_tensor_values); block_runner, managed_output_tensor_values);
} }
num_managed_tensors_ = 0; num_managed_tensors_ = 0;
View File
@ -93,12 +93,8 @@ TORCH_API std::vector<StorageGroup> assignStorageToManagedTensors(
class MemoryPlanner { class MemoryPlanner {
public: public:
explicit MemoryPlanner( explicit MemoryPlanner(
StaticRuntime* runtime, BlockRunner* block_runner,
const ValueGroup& value_group, const BlockInfo& block_info,
const FastSet<const Value*>& managed_tensor_values,
const FastSet<const Value*>& managed_output_tensor_values,
const FastSet<const Value*>& leaked_values,
const ManagedTensorRanges& ranges,
bool enable_out_variant, bool enable_out_variant,
bool manage_output_tensors, bool manage_output_tensors,
bool optimize_memory); bool optimize_memory);
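  // A hedged sketch of a call site for this slimmed-down constructor, e.g.
  // from inside a BlockRunner (flag values are illustrative; the real caller
  // derives them from StaticModuleOptions):
  //
  //   planner_ = std::make_unique<MemoryPlanner>(
  //       this,                // the BlockRunner that owns the planner
  //       block_info_,         // per-block metadata computed by StaticModule
  //       /*enable_out_variant=*/true,
  //       /*manage_output_tensors=*/false,
  //       /*optimize_memory=*/true);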