From 3a9feffd929869120c717d35aa55aad8a382783d Mon Sep 17 00:00:00 2001 From: Mike Iovine Date: Tue, 1 Feb 2022 09:08:41 -0800 Subject: [PATCH] [SR] Add BlockRunner and handle sub-blocks (#69834) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/69834 * Modify the `StaticModule` constructor to handle index initialization for sub-blocks. * Add a new class `StaticRuntimeBlockRunner`. This class is almost exactly like what we've been calling `StaticRuntime` up to this point, except that it does not own a `values_` array. All `StaticRuntimeBlockRunners` hold an unowned reference to a `values_` array owned by `StaticRuntime`. This is a useful abstraction for implementing control flow - it gives us a way for sub-blocks to look up values from surrounding scopes! ghstack-source-id: 148086245 Test Plan: `buck test caffe2/benchmarks/static_runtime/...` Reviewed By: d1jang Differential Revision: D33028039 fbshipit-source-id: 4f01417bad51a0cf09b1680a518308da647be1f6 --- .../static_runtime/test_static_module.cc | 36 +- torch/csrc/jit/runtime/static/impl.cpp | 527 ++++++++++-------- torch/csrc/jit/runtime/static/impl.h | 434 +++++++++++---- .../jit/runtime/static/memory_planner.cpp | 31 +- .../csrc/jit/runtime/static/memory_planner.h | 8 +- 5 files changed, 680 insertions(+), 356 deletions(-) diff --git a/benchmarks/static_runtime/test_static_module.cc b/benchmarks/static_runtime/test_static_module.cc index f6486354834..81556435c08 100644 --- a/benchmarks/static_runtime/test_static_module.cc +++ b/benchmarks/static_runtime/test_static_module.cc @@ -106,7 +106,8 @@ TEST(StaticModule, ValueGroup) { torch::jit::StaticModule sm(input_graph); const Graph& graph = sm.graph(); std::vector nodes(graph.nodes().begin(), graph.nodes().end()); - const auto& value_group = sm.value_group(); + auto* root_block = sm.root_block(); + const auto& value_group = sm.block_info(root_block).value_group(); std::vector expected_input_aliases{ graph.inputs()[0], graph.inputs()[1], nodes[0]->output()}; @@ -138,9 +139,11 @@ TEST(StaticModule, IsOptimizableContainerType_NonOptimizableInputs) { auto sm = makeStaticModuleFromScript(src); const auto& graph = sm.graph(); + auto* root_block = sm.root_block(); + const auto& block_info = sm.block_info(root_block); for (const Node* n : graph.nodes()) { - EXPECT_FALSE(sm.is_optimizable_container_type(n)); + EXPECT_FALSE(block_info.node_is_optimizable_container_type(n)); } } @@ -158,9 +161,11 @@ TEST(StaticModule, IsOptimizableContainerType_WrongType) { auto sm = makeStaticModuleFromScript(src); const auto& graph = sm.graph(); + auto* root_block = sm.root_block(); + const auto& block_info = sm.block_info(root_block); for (const Node* n : graph.nodes()) { - EXPECT_FALSE(sm.is_optimizable_container_type(n)); + EXPECT_FALSE(block_info.node_is_optimizable_container_type(n)); } } @@ -175,12 +180,14 @@ TEST(StaticModule, IsOptimizableContainerType_CanUseOutVariant) { )JIT"; auto sm = makeStaticModuleFromScript(src); const auto& graph = sm.graph(); + auto* root_block = sm.root_block(); + const auto& block_info = sm.block_info(root_block); for (const Node* n : graph.nodes()) { if (n->kind() == c10::prim::ListConstruct) { - EXPECT_TRUE(sm.is_optimizable_container_type(n)); + EXPECT_TRUE(block_info.node_is_optimizable_container_type(n)); } else { - EXPECT_FALSE(sm.is_optimizable_container_type(n)); + EXPECT_FALSE(block_info.node_is_optimizable_container_type(n)); } } } @@ -1050,7 +1057,8 @@ TEST(ManagedTensorRanges, NoAliases) { auto* z = vmap["z"]; FastSet managed_tensors = {y, 
z}; - ManagedTensorRanges ranges(graph, managed_tensors); + AliasDb alias_db(graph); + auto ranges = ManagedTensorRanges(*graph->block(), alias_db, managed_tensors); std::vector nodes( graph->block()->nodes().begin(), graph->block()->nodes().end()); @@ -1089,7 +1097,8 @@ TEST(ManagedTensorRanges, AliasExtendingLifetimes) { auto* z2 = vmap["z2"]; FastSet managed_tensors = {y, z1, z2}; - ManagedTensorRanges ranges(graph, managed_tensors); + AliasDb alias_db(graph); + auto ranges = ManagedTensorRanges(*graph->block(), alias_db, managed_tensors); std::vector nodes( graph->block()->nodes().begin(), graph->block()->nodes().end()); @@ -1135,7 +1144,8 @@ TEST(ManagedTensorRanges, LifetimeOverlap) { auto* d = vmap["d"]; auto* e = vmap["e"]; - ManagedTensorRanges ranges(graph, {b, c, d, e}); + AliasDb alias_db(graph); + auto ranges = ManagedTensorRanges(*graph->block(), alias_db, {b, c, d, e}); const std::vector> overlapping_values{ {b, c}, {c, d}, {c, e}}; @@ -1169,7 +1179,8 @@ TEST(ManagedTensorRanges, OverlappingLifetimesContainers) { auto* c = vmap["c"]; auto* d = vmap["d"]; - ManagedTensorRanges ranges(graph, {b, c, d}); + AliasDb alias_db(graph); + auto ranges = ManagedTensorRanges(*graph->block(), alias_db, {b, c, d}); EXPECT_TRUE(ranges.lifetimesOverlap(b, c)); EXPECT_TRUE(ranges.lifetimesOverlap(b, d)); @@ -1189,7 +1200,8 @@ TEST(ManagedTensorRanges, OverlappingLifetimesOutputs) { auto* b = vmap["b"]; auto* output = vmap["output"]; - ManagedTensorRanges ranges(graph, {b, output}); + AliasDb alias_db(graph); + auto ranges = ManagedTensorRanges(*graph->block(), alias_db, {b, output}); EXPECT_TRUE(ranges.lifetimesOverlap(b, output)); } @@ -1275,7 +1287,9 @@ void testAssignStorageToManagedTensors( } ASSERT_EQ(managed_tensor_values.size(), tensor_value_to_tensor.size()); - auto ranges = ManagedTensorRanges(graph, managed_tensor_values); + AliasDb alias_db(graph); + auto ranges = + ManagedTensorRanges(*graph->block(), alias_db, managed_tensor_values); auto groups = assignStorageToManagedTensors( graph->block()->nodes(), ranges, tensor_value_to_tensor); diff --git a/torch/csrc/jit/runtime/static/impl.cpp b/torch/csrc/jit/runtime/static/impl.cpp index 6cafc8c4829..77b8672d331 100644 --- a/torch/csrc/jit/runtime/static/impl.cpp +++ b/torch/csrc/jit/runtime/static/impl.cpp @@ -19,6 +19,7 @@ #include #include #include +#include #include #include #include @@ -32,6 +33,7 @@ #endif #include +#include #include #include @@ -58,8 +60,8 @@ bool isUnsupportedOp(const NodeKind& kind) { return kind == aten::__is__ || kind == aten::__isnot__; } -// graph must be frozen or canEnableStaticRuntime would return false if there's -// any prim::CallMethod op left in the graph +// graph must be frozen or canEnableStaticRuntime would return false +// if there's any prim::CallMethod op left in the graph bool canEnableStaticRuntime(const std::shared_ptr& graph) { // check for sub-blocks bool can_support = true; @@ -181,26 +183,20 @@ std::vector valueVecFromFastSet(const FastSet& s) { return result; } -bool mayContainAlias(AliasDb& db, const Value* a, const Value* b) { +bool mayContainAlias(const AliasDb& db, const Value* v1, const Value* v2) { + // AliasDb is not const-correct here, so we have to const_cast // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) - return db.mayContainAlias(const_cast(a), const_cast(b)); + return db.mayContainAlias(const_cast(v1), const_cast(v2)); } bool mayContainAlias( - AliasDb& db, + const AliasDb& db, const Value* a, const FastSet& b) { // 
NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) return db.mayContainAlias(const_cast(a), valueVecFromFastSet(b)); } -bool mayContainAlias( - AliasDb& db, - const FastSet& a, - const FastSet& b) { - return db.mayContainAlias(valueVecFromFastSet(a), valueVecFromFastSet(b)); -} - void PrepareGraphForStaticModule( std::shared_ptr graph, const StaticModuleOptions& opts, @@ -248,23 +244,21 @@ std::pair, c10::optional> PrepareForStaticModule( } // namespace -void ValueGroup::init( - const std::shared_ptr& graph, - AliasDb& db) { +void ValueGroup::init(const Block& block, const AliasDb& db) { external_aliases_.clear(); output_aliases_.clear(); // Build `external_aliases` as we look through nodes forwardly from // the graph's inputs and add aliases of the inputs being created by the // nodes. - external_aliases_.insert(graph->inputs().begin(), graph->inputs().end()); - for (const auto* node : graph->nodes()) { + external_aliases_.insert(block.inputs().begin(), block.inputs().end()); + for (const auto* node : block.nodes()) { if (node->kind() == prim::Constant) { for (const auto* output : node->outputs()) { external_aliases_.insert(output); } } } - for (const auto* node : graph->nodes()) { + for (const auto* node : block.nodes()) { if (node->kind() == prim::Constant) { // Constants are already in `external_aliases`. continue; @@ -278,8 +272,8 @@ void ValueGroup::init( // Build `output_aliases` as we look through nodes reversely so that we can // start from the output values, and follow the flows backwardly from there. - output_aliases_.insert(graph->outputs().begin(), graph->outputs().end()); - for (const auto* node : graph->nodes().reverse()) { + output_aliases_.insert(block.outputs().begin(), block.outputs().end()); + for (const auto* node : block.nodes().reverse()) { if (node->kind() == prim::Constant) { // Constants cannot create any aliases. continue; @@ -317,12 +311,6 @@ bool containTensorsOnly(at::ArrayRef values) { }); } -bool mayContainAlias(const Value* v1, const Value* v2, const AliasDb& db) { - // AliasDb is not const-correct here, so we have to const_cast - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) - return db.mayContainAlias(const_cast(v1), const_cast(v2)); -} - bool isPureFunction(const Node* node) { auto* schema = node->maybeSchema(); return schema && @@ -332,12 +320,12 @@ bool isPureFunction(const Node* node) { } // namespace ManagedTensorRanges::ManagedTensorRanges( - const std::shared_ptr& graph, + Block& block, + const AliasDb& alias_db, const FastSet& managed_tensor_values) { - AliasDb alias_db(graph); - const std::vector nodes(graph->nodes().begin(), graph->nodes().end()); + const std::vector nodes(block.nodes().begin(), block.nodes().end()); const FastSet graph_inputs( - graph->inputs().begin(), graph->inputs().end()); + block.inputs().begin(), block.inputs().end()); auto isUntrackedValue = [&alias_db, &graph_inputs](const Value* value) { return !alias_db.isMutableType(value) || @@ -363,7 +351,7 @@ ManagedTensorRanges::ManagedTensorRanges( value_lifetimes_.emplace(output, Lifetime(i, i)); } } - for (auto* graph_output : graph->outputs()) { + for (auto* graph_output : block.outputs()) { auto* lifetime = getLifetime(graph_output); if (!lifetime) { DCHECK(isUntrackedValue(graph_output)); @@ -376,7 +364,7 @@ ManagedTensorRanges::ManagedTensorRanges( // has an input and output that may alias each other, set the input's // lifetime end to max(input.lifetime_end, output.lifetime_end). Iterate // backwards to handle chains of aliases. 
- for (const auto* node : graph->nodes().reverse()) { + for (const auto* node : block.nodes().reverse()) { if (isPureFunction(node)) { // If the node is a pure function, it doesn't create any aliases, // so we can safely skip it. @@ -389,7 +377,7 @@ ManagedTensorRanges::ManagedTensorRanges( auto* input_lifetime = getLifetime(input); DCHECK(input_lifetime != nullptr); for (auto* output : outputs) { - if (mayContainAlias(input, output, alias_db)) { + if (mayContainAlias(alias_db, input, output)) { auto* output_lifetime = getLifetime(output); DCHECK(output_lifetime != nullptr); input_lifetime->end = @@ -404,7 +392,7 @@ ManagedTensorRanges::ManagedTensorRanges( // NOLINTNEXTLINE(cppcoreguidelines-init-variables) Node* freeing_node; if (lifetime->end == num_nodes) { - freeing_node = graph->return_node(); + freeing_node = block.return_node(); } else { freeing_node = nodes[lifetime->end]; } @@ -519,15 +507,6 @@ StaticModule::StaticModule( } } - // map Value* to its SSA definition IR - FastMap value_to_ssa_def; - - // N inputs map to the first N entries in storage - for (const auto i : c10::irange(graph_->inputs().size())) { - Value* input = graph_->inputs()[i]; - value_to_ssa_def[input] = std::make_pair(INPUT_VALUE, i); - } - { size_t nodes_size = 0, constants_size = 0; for (Node* node : graph_->nodes()) { @@ -536,7 +515,6 @@ StaticModule::StaticModule( constants_.reserve(constants_size); functions_.reserve(nodes_size); - nodes_.reserve(nodes_size); } // Create ProcessedFunction instances first to freeze their addresses to pass @@ -544,13 +522,89 @@ StaticModule::StaticModule( AliasDb alias_db(graph_, /*isFrozen=*/false); GRAPH_DEBUG("AliasDb: ", alias_db.toString()); - // Construct constant and function nodes - for (Node* node : graph_->nodes()) { + // Maps each Value* in the graph to its index in the values_ array that will + // eventually be created by StaticRuntime. 
+ FastMap value_to_index; + prepareFunctionsAndConstants(graph_->block(), alias_db, value_to_index); + + const auto constants_index_offset = 0; + const auto values_index_offset = constants_index_offset + constants().size(); + value_buffer_size_ = values_index_offset; + + value_buffer_size_ += + prepareBlockInfo(graph_->block(), values_index_offset, value_to_index); + + prepareProcessedNodes(graph_->block(), value_to_index, alias_db); + + for (auto& block_and_info : block_infos_) { + auto& block_info = block_and_info.second; + block_info.prepare_for_memory_planner(alias_db, opts); + } +} + +size_t StaticModule::prepareBlockInfo( + Block* block, + const size_t start_idx, + FastMap& value_to_index) { + block_infos_.emplace(block, BlockInfo(start_idx, *block)); + + const auto num_inputs = block->inputs().size(); + for (const auto i : c10::irange(num_inputs)) { + value_to_index.emplace(block->inputs()[i], start_idx + i); + } + auto cur_idx = start_idx + num_inputs; + + for (auto* node : block->nodes()) { + for (auto* sub_block : node->blocks()) { + cur_idx += prepareBlockInfo(sub_block, cur_idx, value_to_index); + } + + if (node->kind() == prim::Constant) { + continue; + } + + TORCH_CHECK( + cur_idx < (1 << 16), + "outputs offset in values table", + cur_idx, + " would overflow 2-byte index storage"); + + const auto num_outputs = node->outputs().size(); + for (const auto i : c10::irange(num_outputs)) { + value_to_index.emplace(node->outputs()[i], cur_idx + i); + } + cur_idx += num_outputs; + } + + std::vector output_indices; + output_indices.reserve(block->outputs().size()); + for (auto* output : block->outputs()) { + const auto output_idx = value_to_index.at(output); + TORCH_CHECK( + output_idx < (1 << 16), + "outputs offset in values table", + output_idx, + " would overflow 2-byte index storage"); + output_indices.push_back(output_idx); + } + + block_infos_.at(block).set_output_indices(std::move(output_indices)); + return cur_idx - start_idx; +} + +void StaticModule::prepareFunctionsAndConstants( + Block* block, + const AliasDb& alias_db, + FastMap& value_to_index) { + for (auto* node : block->nodes()) { + for (auto* sub_block : node->blocks()) { + prepareFunctionsAndConstants(sub_block, alias_db, value_to_index); + } + if (node->kind() == prim::Constant) { auto* v = node->output(); TORCH_CHECK(v->type()->kind() != FunctionType::Kind); - // construct SSA definition for constant nodes - value_to_ssa_def[v] = std::make_pair(CONSTANT_VALUE, constants_.size()); + value_to_index.emplace(v, constants_.size()); constants_.emplace_back(toIValue(v).value()); continue; } @@ -561,66 +615,34 @@ StaticModule::StaticModule( containTensorsOnly(node->outputs()); // new ProcessedFunction functions_.emplace_back( - node, opts.enable_out_variant, check_outputs_for_overlap); + node, opts_.enable_out_variant, check_outputs_for_overlap); } +} - // construct SSA definition for non-constant nodes - int node_idx = 0; +size_t StaticModule::prepareProcessedNodes( + Block* block, + const FastMap& value_to_index, + const AliasDb& alias_db, + size_t node_idx) { + const auto node_start = node_idx; + + auto& block_info = block_infos_.at(block); + std::vector nodes; FastMap node_has_out_variant; - const auto inputs_index_offset = inputs_offset(); - const auto constants_index_offset = constants_offset(); - const auto values_index_offset = intermediate_values_offset(); - - // Map node_idx to index offset in values_. Can't reserve space - // because we don't know how many non-constant nodes there are yet. 
- std::vector node_output_idx_map; - uint32_t node_outputs_seen_so_far = 0; - for (Node* node : graph_->nodes()) { + for (auto* node : block->nodes()) { if (node->kind() == prim::Constant) { continue; } - // Assign memory for the outputs - const auto outputs_offset_for_node = - node_outputs_seen_so_far + values_index_offset; - TORCH_CHECK( - outputs_offset_for_node < (1 << 16), - "outputs offset in values table", - outputs_offset_for_node, - " would overflow 2-byte index storage"); - node_output_idx_map.push_back(outputs_offset_for_node); - node_outputs_seen_so_far += node->outputs().size(); - } - for (Node* node : graph_->nodes()) { - if (node->kind() == prim::Constant) { - continue; + for (auto* sub_block : node->blocks()) { + node_idx += + prepareProcessedNodes(sub_block, value_to_index, alias_db, node_idx); } ProcessedNodeInputs input_indices(node->inputs().size()); - std::vector input_ssa_defs; for (const auto input_idx : c10::irange(node->inputs().size())) { - Value* const input = node->inputs()[input_idx]; - int inner_node_idx = 0; - int out_idx = 0; - std::tie(inner_node_idx, out_idx) = value_to_ssa_def.at(input); - unsigned int input_ivalue_idx = 0; - if (inner_node_idx == StaticModule::INPUT_VALUE) { - input_ivalue_idx = out_idx + inputs_index_offset; - } else if (inner_node_idx == StaticModule::CONSTANT_VALUE) { - input_ivalue_idx = out_idx + constants_index_offset; - } else { - DCHECK_GE(inner_node_idx, 0); - const auto global_value_idx = - node_output_idx_map[inner_node_idx] + out_idx; - if (inner_node_idx < node_output_idx_map.size() - 1) { - DCHECK_LT(global_value_idx, node_output_idx_map[inner_node_idx + 1]); - } else { - DCHECK_LT( - global_value_idx, - constants_index_offset + node_outputs_seen_so_far); - } - input_ivalue_idx = global_value_idx; - } + auto* input = node->inputs()[input_idx]; + auto input_ivalue_idx = value_to_index.at(input); TORCH_CHECK( input_ivalue_idx < (1 << 16), "input index in values table ", @@ -630,72 +652,48 @@ StaticModule::StaticModule( } ProcessedFunction* fn = &functions_[node_idx]; + // create a new ProcessedNode - // see [Check and correct bad schema alias info at runtime] - bool check_outputs_for_overlap = - !alias_db.mayContainAlias(node->inputs(), node->outputs()) && - containTensorsOnly(node->outputs()); - nodes_.emplace_back( - node, fn, std::move(input_indices), node_output_idx_map[node_idx]); + const auto node_output_idx = node->outputs().empty() + // The index is unused if there are no outputs, so just create a + // placeholder value. + ? 
std::numeric_limits::max() + : value_to_index.at(node->output(0)); + nodes.emplace_back(node, fn, std::move(input_indices), node_output_idx); - node_has_out_variant.emplace(node, nodes_.back().has_out_variant()); - for (const auto i : c10::irange(node->outputs().size())) { - value_to_ssa_def[node->outputs()[i]] = std::make_pair(node_idx, i); - } - node_idx++; + node_has_out_variant.emplace(node, nodes.back().has_out_variant()); + ++node_idx; } - num_intermediate_values_ = std::accumulate( - nodes_.begin(), - nodes_.end(), - 0, - [](uint32_t sum, const ProcessedNode& pnode) { - return sum + pnode.num_outputs(); - }); + block_info.set_nodes(std::move(nodes), node_has_out_variant); + block_info.init_value_group(alias_db); - for (auto& pnode : nodes_) { - if (pnode.num_outputs() == 1 && - isOptimizableContainerType(pnode.node(), node_has_out_variant)) { - node_is_optimizable_container_type_.emplace(pnode.node()); - } - } - output_indices_.reserve(graph_->outputs().size()); - for (auto output : graph_->outputs()) { - int node_idx = 0; - int out_idx = 0; - std::tie(node_idx, out_idx) = value_to_ssa_def[output]; - uint32_t output_index = 0; - if (node_idx == StaticModule::INPUT_VALUE) { - output_index = out_idx + inputs_index_offset; - } else if (node_idx == StaticModule::CONSTANT_VALUE) { - output_index = constants_index_offset + out_idx; - } else { - output_index = nodes_[node_idx].output_ivalue_index(out_idx); - } - TORCH_CHECK( - output_index < (1 << 16), - "output index ", - output_index, - " would overflow 2-byte index storage"); - output_indices_.emplace_back(output_index); - } - - // Prepare for memory planning - value_group_.init(graph_, alias_db); - GRAPH_DEBUG(value_group_.toString()); - - prepareForMemoryPlanner(); + return node_idx - node_start; } -void StaticModule::prepareForMemoryPlanner() { - if (!opts_.enable_out_variant) { +void BlockInfo::set_nodes( + std::vector nodes, + const FastMap& node_has_out_variant) { + nodes_ = std::move(nodes); + + for (auto& node : nodes_) { + if (node.num_outputs() == 1 && + isOptimizableContainerType(node.node(), node_has_out_variant)) { + node_is_optimizable_container_type_.emplace(node.node()); + } + } +} +void BlockInfo::prepare_for_memory_planner( + const AliasDb& alias_db, + const StaticModuleOptions& opts) { + if (!opts.enable_out_variant) { return; } // Never manage graph outputs so that we can do std::move(output_ivalue). // This does not affect performance if the graph returns a collection object. 
FastSet graph_output_values( - graph_->outputs().begin(), graph_->outputs().end()); + block_.outputs().begin(), block_.outputs().end()); // collect register indices of outputs of ops with out variant for (ProcessedNode& pnode : nodes_) { @@ -707,7 +705,7 @@ void StaticModule::prepareForMemoryPlanner() { const Value* out_v = outputs[i]; // Types are stored in the underlying TorchScript IR bool is_tensor_type = out_v->type()->castRaw(); - if (opts_.manage_output_tensors && is_tensor_type && + if (opts.manage_output_tensors && is_tensor_type && graph_output_values.find(out_v) == graph_output_values.end() && value_group_.isOutputAlias(out_v)) { managed_output_tensor_values_.insert(out_v); @@ -718,7 +716,7 @@ void StaticModule::prepareForMemoryPlanner() { } if (is_tensor_type) { managed_tensor_values_.insert(out_v); - } else if (is_optimizable_container_type(pnode.node())) { + } else if (node_is_optimizable_container_type(pnode.node())) { // We "leak" certain container types because their allocations // take a long time leaked_values_.insert(out_v); @@ -726,7 +724,7 @@ void StaticModule::prepareForMemoryPlanner() { } } - for (const Value* output : graph_->outputs()) { + for (const Value* output : block_.outputs()) { managed_tensor_values_.erase(output); } GRAPH_DEBUG("managed_tensor_values: ", dumpValueSet(managed_tensor_values_)); @@ -734,7 +732,8 @@ void StaticModule::prepareForMemoryPlanner() { "managed_output_tensor_values_: ", dumpValueSet(managed_output_tensor_values_)); - managed_tensor_ranges_ = ManagedTensorRanges(graph_, managed_tensor_values_); + managed_tensor_ranges_ = + ManagedTensorRanges(block_, alias_db, managed_tensor_values_); } const StaticModuleOptions& StaticModule::opts() const { @@ -757,9 +756,12 @@ StaticRuntime& StaticModule::runtime() { } Node* StaticModule::findNodeWithKindForTesting(const std::string& kind) const { - for (auto& pnode : nodes()) { - if (pnode.node()->kind().toQualString() == kind) { - return pnode.node(); + for (auto& block_and_info : block_infos_) { + auto& block_info = block_and_info.second; + for (auto& pnode : block_info.nodes()) { + if (pnode.node()->kind().toQualString() == kind) { + return pnode.node(); + } } } return nullptr; @@ -777,41 +779,64 @@ c10::IValue StaticModule::operator()( return runtime()(std::move(args), kwargs); } -StaticRuntime::StaticRuntime(const StaticModule& sm) +BlockRunner::BlockRunner( + const StaticModule& sm, + std::vector& values, + Block* block, + bool is_root_block) : static_module_(sm), - first_input_is_self_(static_module_.first_input_is_self()), - manage_output_tensors_enabled_(sm.opts().manage_output_tensors), - nodes_(sm.nodes()) { - values_.resize(sm.total_num_values()); - const auto constants_index_offset = sm.constants_offset(); - const auto constants_begin_it = values_.begin() + constants_index_offset; - const auto constants_end_it = constants_begin_it + sm.constants().size(); - std::copy(sm.constants().begin(), sm.constants().end(), constants_begin_it); - - for (const auto idx : c10::irange(sm.nodes().size())) { - auto& n = nodes_[idx]; + block_info_(static_module_.block_info(block)), + is_root_block_(is_root_block), + first_input_is_self_( + is_root_block_ && static_module_.first_input_is_self()), + inputs_begin_(block_info_.block_inputs_idx()), + // TODO(T108633124): Turn on manage output tensors for sub-blocks. 
+ manage_output_tensors_enabled_( + is_root_block_ && sm.opts().manage_output_tensors), + values_(values), + nodes_(block_info_.nodes()) { + for (auto& n : nodes_) { n.set_values(values_.data()); } - // TODO: can we convert outputs_ to store indices? - for (auto index : sm.output_indices()) { + for (auto index : block_info_.block_output_indices()) { outputs_.emplace_back(&values_[index]); } + + for (auto& pnode : nodes_) { + auto* node = pnode.node(); + auto blocks = node->blocks(); + const auto num_blocks = blocks.size(); + if (num_blocks == 0) { + continue; + } + DCHECK(node->kind() == prim::If || node->kind() == prim::Loop); + auto block_runners = std::make_unique>(); + block_runners->reserve(num_blocks); + + for (auto* b : blocks) { + block_runners->emplace_back(sm, values, b); + } + pnode.set_block_runners(std::move(block_runners)); + } } -StaticRuntime::~StaticRuntime() = default; +// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) +BlockRunner::BlockRunner(BlockRunner&&) noexcept = default; -void StaticRuntime::set_arg(const size_t idx, std::vector&& args) { +BlockRunner::~BlockRunner() = default; + +void BlockRunner::set_arg(const size_t idx, std::vector&& args) { DCHECK(idx < args.size()); Input(idx + first_input_is_self_) = std::move(args[idx]); } -void StaticRuntime::set_arg(const size_t idx, const std::vector& args) { +void BlockRunner::set_arg(const size_t idx, const std::vector& args) { DCHECK(idx < args.size()); Input(idx + first_input_is_self_) = args[idx]; } -void StaticRuntime::set_arg(const size_t idx, const IValue& arg) { +void BlockRunner::set_arg(const size_t idx, const IValue& arg) { Input(idx + first_input_is_self_) = arg; } @@ -827,24 +852,21 @@ void check_type(const Argument& schema_arg, const IValue& arg) { } // namespace template -void StaticRuntime::set_inputs( +void BlockRunner::set_inputs( IValueList&& args, const std::unordered_map& kwargs) { const auto total_num_inputs = args.size() + kwargs.size() + first_input_is_self_; - TORCH_CHECK(total_num_inputs == static_module_.num_inputs()); + TORCH_CHECK(total_num_inputs == block_info_.num_inputs()); const auto& schema = static_module_.schema(); if (first_input_is_self_) { Input(0) = static_module_.module()._ivalue(); } - if (C10_UNLIKELY(!schema)) { + if (!is_root_block_ || C10_UNLIKELY(!schema)) { TORCH_CHECK( - kwargs.empty(), - "Schema is not available, but StaticRuntime got kwargs. 
" - "Consider creating the Static Runtime instance " - "with StaticModule(const torch::jit::Module& m) instead."); + kwargs.empty(), "Schema is not available, but BlockRunner got kwargs."); for (size_t i = 0; i < args.size(); ++i) { set_arg(i, std::forward(args)); } @@ -887,15 +909,11 @@ void StaticRuntime::set_inputs( args.size() + consumed_kwargs == schema_args.size() - 1); } -void StaticRuntime::create_memory_planner() { +void BlockRunner::create_memory_planner() { if (!planner_) { planner_ = std::make_unique( this, - static_module_.value_group(), - static_module_.managed_tensor_values(), - static_module_.managed_output_tensor_values(), - static_module_.leaked_values(), - static_module_.managed_tensor_ranges(), + block_info_, static_module_.opts().enable_out_variant, manage_output_tensors_enabled_, static_module_.opts().optimize_memory); @@ -924,7 +942,7 @@ void destroyNodeOutputs(ProcessedNode& p_node) { } // namespace -void StaticRuntime::clean_up_intermediate_ivalues() noexcept { +void BlockRunner::clean_up_intermediate_ivalues() noexcept { // We have to iterate in reverse order here due to borrowed // IValues - we don't want to destroy a value until all of its // borrows are cleaned up! @@ -933,7 +951,7 @@ void StaticRuntime::clean_up_intermediate_ivalues() noexcept { } } -void StaticRuntime::resetMemory() noexcept { +void BlockRunner::resetMemory() noexcept { planner_.reset(); // We must clean up intermediate values before inputs in case // there are borrowed inputs and static runtime owns the only @@ -942,7 +960,7 @@ void StaticRuntime::resetMemory() noexcept { clean_up_input_ivalues(); } -c10::IValue StaticRuntime::move_outputs_to_tuple(uint32_t num_outputs) { +c10::IValue BlockRunner::move_outputs_to_tuple(uint32_t num_outputs) { switch (num_outputs) { case 1: return c10::ivalue::Tuple::create(IValue(std::move(*outputs_[0]))); @@ -1032,7 +1050,7 @@ c10::IValue StaticRuntime::move_outputs_to_tuple(uint32_t num_outputs) { /// buffer) fails. There is still a corner case that fails with the added flag. /// If a resize is triggered at the same time as the op creating an alias at the /// same time, the current checks would fail to detect the alias. -void StaticRuntime::verify_and_correct_memory_overlap(ProcessedNode& n) { +void BlockRunner::verify_and_correct_memory_overlap(ProcessedNode& n) { // The slow check can be removed once the internal/output buffers are merged if (C10_UNLIKELY(n.check_outputs_for_memory_overlap())) { if (C10_UNLIKELY(!planner_)) { @@ -1065,7 +1083,7 @@ void StaticRuntime::verify_and_correct_memory_overlap(ProcessedNode& n) { } } -bool StaticRuntime::fast_check_and_correct_overlap_with( +bool BlockRunner::fast_check_and_correct_overlap_with( ProcessedNode& n, c10::IValue& tensor_ival) { auto& tensor = tensor_ival.toTensor(); @@ -1078,38 +1096,38 @@ bool StaticRuntime::fast_check_and_correct_overlap_with( return false; } -StaticRuntime::Deallocator::~Deallocator() { +BlockRunner::Deallocator::~Deallocator() { // Assume cleanup cannot throw. cleanupImpl(); #ifndef NDEBUG - runtime_.check_for_memory_leak(/*output_returned*/ false); + block_runner_.check_for_memory_leak(/*output_returned*/ false); #endif } -void StaticRuntime::Deallocator::cleanupImpl() { +void BlockRunner::Deallocator::cleanupImpl() { // MemoryPlanner is created after the first invocation of `run()`. 
This // is done intentionally because MemoryPlanner uses `Tensor` sizes of // the previous `run()` for memory planning of subsequent runs if (C10_LIKELY(finished_)) { - runtime_.create_memory_planner(); + block_runner_.create_memory_planner(); } - if (C10_LIKELY(runtime_.planner_)) { - runtime_.planner_->deallocate(); + if (C10_LIKELY(block_runner_.planner_)) { + block_runner_.planner_->deallocate(); } else { // This is the first run, and it didn't finish, so we can't use a // `MemoryPlanner` to deallocate stuff. Just reset everything mannually. - runtime_.resetMemory(); + block_runner_.resetMemory(); } // clean up owning refs of input tensors - runtime_.clean_up_input_ivalues(); + block_runner_.clean_up_input_ivalues(); if (C10_UNLIKELY(!finished_)) { - runtime_.deallocateOutputTensors(); + block_runner_.deallocateOutputTensors(); } } template -c10::IValue StaticRuntime::run_impl( +c10::IValue BlockRunner::run_impl( IValueList&& args, const KeywordArgs& kwargs) { // We assume inference workloads, so we do not need @@ -1138,8 +1156,8 @@ c10::IValue StaticRuntime::run_impl( } // no need to keep references of outputs in static runtime anymore - if (static_module_.num_outputs() > 1) { - return move_outputs_to_tuple(static_module_.num_outputs()); + if (block_info_.num_outputs() > 1) { + return move_outputs_to_tuple(block_info_.num_outputs()); } DCHECK(check_for_memory_leak(/*output_returned*/ false)); @@ -1149,7 +1167,7 @@ c10::IValue StaticRuntime::run_impl( } template -c10::IValue StaticRuntime::run_impl_record_functions( +c10::IValue BlockRunner::run_impl_record_functions( IValueList&& args, const KeywordArgs& kwargs) { bool pre_sampled = false; @@ -1168,7 +1186,7 @@ c10::IValue StaticRuntime::run_impl_record_functions( return run_impl(std::forward(args), kwargs); } -c10::IValue StaticRuntime::operator()( +c10::IValue BlockRunner::operator()( const std::vector& args, const KeywordArgs& kwargs) { #ifdef PYTORCH_DISABLE_NET_PROFILING @@ -1178,7 +1196,7 @@ c10::IValue StaticRuntime::operator()( #endif } -c10::IValue StaticRuntime::operator()( +c10::IValue BlockRunner::operator()( std::vector&& args, const KeywordArgs& kwargs) { #ifdef PYTORCH_DISABLE_NET_PROFILING @@ -1205,7 +1223,7 @@ std::string generate_latency_json(const std::string& label, double millis) { } // namespace -void StaticRuntime::benchmark( +void BlockRunner::benchmark( const std::vector>& args_list, const std::vector& kwargs_list, const int warmup_runs, @@ -1267,7 +1285,7 @@ void StaticRuntime::benchmark( } std::cout << std::setw(15) << results.total_time << " ms. 
in Total" << std::endl; - std::cout << "StaticRuntime setup time: " << results.setup_time << " ms" + std::cout << "BlockRunner setup time: " << results.setup_time << " ms" << std::endl; std::cout << "Memory allocation time: " << results.memory_alloc_time << " ms\n"; @@ -1312,7 +1330,7 @@ void StaticRuntime::benchmark( #endif } -float StaticRuntime::benchmark_model( +float BlockRunner::benchmark_model( const std::vector>& args_list, const std::vector& kwargs_list, const int warmup_runs, @@ -1396,7 +1414,7 @@ void display_pnode_info(const ProcessedNode& pnode) { } } -void StaticRuntime::display_nodes( +void BlockRunner::display_nodes( const std::vector& args, const KeywordArgs& kwargs) { c10::InferenceMode mode; @@ -1415,7 +1433,7 @@ void StaticRuntime::display_nodes( on_exit.setFinished(); } -StaticRuntime::IndividualMetrics StaticRuntime::benchmark_individual_ops( +BlockRunner::IndividualMetrics BlockRunner::benchmark_individual_ops( const std::vector>& args_list, const std::vector& kwargs_list, const int warmup_runs, @@ -1543,10 +1561,16 @@ StaticRuntime::IndividualMetrics StaticRuntime::benchmark_individual_ops( return results; } -bool StaticRuntime::check_for_memory_leak(bool output_returned) { +bool BlockRunner::check_for_memory_leak( + bool output_returned, + bool recurse_on_sub_blocks) { // check for inputs - for (const auto i : c10::irange(static_module_.num_inputs())) { - TORCH_CHECK(values_[i].isNone(), "Input ", i, " was not cleaned up"); + for (const auto i : c10::irange(block_info_.num_inputs())) { + TORCH_CHECK( + values_[i + block_info_.block_inputs_idx()].isNone(), + "Input ", + i, + " was not cleaned up"); } FastSet output_ivalues(outputs_.begin(), outputs_.end()); for (const auto n : c10::irange(nodes_.size())) { @@ -1561,7 +1585,8 @@ bool StaticRuntime::check_for_memory_leak(bool output_returned) { (isManagedOutputTensor(*ival) || isManagedOutputTensorValue(val))) { // `ival` contains a managed output tensor that the runtime doesn't // reclaim at the end of an iteration, but the client does so - // by explicitly calling `StaticRuntime::deallocateOutputTensors`. + // by explicitly calling + // `BlockRunner::deallocateOutputTensors`. 
continue; } const std::string error_msg = "Output " + c10::to_string(i) + ", %" + @@ -1573,7 +1598,8 @@ bool StaticRuntime::check_for_memory_leak(bool output_returned) { if (!ival->isNone()) { TORCH_CHECK( ival->isTensor() || - static_module_.is_optimizable_container_type(pnode.node()) || + block_info_.node_is_optimizable_container_type( + pnode.node()) || doesNotHeapAllocateWhenStoredInIValue(*val->type()), error_msg); if (ival->isTensor()) { @@ -1595,12 +1621,20 @@ bool StaticRuntime::check_for_memory_leak(bool output_returned) { } } } + + auto* block_runners = pnode.block_runners(); + if (recurse_on_sub_blocks && block_runners) { + for (auto& block_runner : *block_runners) { + block_runner.check_for_memory_leak( + output_returned, recurse_on_sub_blocks); + } + } } VLOG(1) << "Finished checking for memory leak"; return true; } -void StaticRuntime::deallocateOutputTensors() { +void BlockRunner::deallocateOutputTensors() { if (!static_module_.opts().manage_output_tensors) { TORCH_CHECK( !planner_ || planner_->numOutputBufferBytes() == 0, @@ -1613,7 +1647,7 @@ void StaticRuntime::deallocateOutputTensors() { } } -bool StaticRuntime::checkOutputTensorMemoryLeaks() { +bool BlockRunner::checkOutputTensorMemoryLeaks() { if (!static_module_.opts().manage_output_tensors || !planner_) { return true; } @@ -1639,21 +1673,21 @@ bool StaticRuntime::checkOutputTensorMemoryLeaks() { return true; } -bool StaticRuntime::isManagedOutputTensor(const IValue& ivalue) const { +bool BlockRunner::isManagedOutputTensor(const IValue& ivalue) const { return planner_ && planner_->isManagedOutputTensor(ivalue); } -bool StaticRuntime::isManagedOutputTensorValue(const Value* value) const { +bool BlockRunner::isManagedOutputTensorValue(const Value* value) const { // It's possible that manage_output_tensors_ was disabled after initializing // managed_output_tensor_values, so we have to check that flag here. 
if (!planner_ || !manage_output_tensors_enabled_) { return false; } - const auto& managed_outputs = static_module_.managed_output_tensor_values(); + const auto& managed_outputs = block_info_.managed_output_tensor_values(); return managed_outputs.find(value) != managed_outputs.end(); } -void StaticRuntime::disableManageOutputTensors() { +void BlockRunner::disableManageOutputTensors() { if (!manage_output_tensors_enabled_) { return; } @@ -1915,5 +1949,50 @@ void ProcessedNode::verify_and_correct_memory_overlap() { } } +StaticRuntime::StaticRuntime(const StaticModule& sm) { + values_.resize(sm.value_buffer_size()); + std::copy(sm.constants().begin(), sm.constants().end(), values_.begin()); + block_ = std::make_unique( + sm, values_, sm.root_block(), /*is_root_block*/ true); + ; +} + +c10::IValue StaticRuntime::operator()( + const std::vector& args, + const KeywordArgs& kwargs) { + return (*block_)(args, kwargs); +} + +c10::IValue StaticRuntime::operator()( + std::vector&& args, + const KeywordArgs& kwargs) { + return (*block_)(std::move(args), kwargs); +} + +bool StaticRuntime::check_for_memory_leak(bool output_returned) { + return block_->check_for_memory_leak( + output_returned, /* recurse_on_sub_blocks */ true); +} + +bool StaticRuntime::checkOutputTensorMemoryLeaks() { + return block_->checkOutputTensorMemoryLeaks(); +} + +void StaticRuntime::deallocateOutputTensors() { + block_->deallocateOutputTensors(); +} + +bool StaticRuntime::isManagedOutputTensor(const IValue& ivalue) const { + return block_->isManagedOutputTensor(ivalue); +} + +void StaticRuntime::disableManageOutputTensors() { + block_->disableManageOutputTensors(); +} + +const MemoryPlanner* StaticRuntime::get_memory_planner() const { + return block_->get_memory_planner(); +} + } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/runtime/static/impl.h b/torch/csrc/jit/runtime/static/impl.h index 5659aab012a..6beda9ce564 100644 --- a/torch/csrc/jit/runtime/static/impl.h +++ b/torch/csrc/jit/runtime/static/impl.h @@ -13,6 +13,7 @@ #include #include #include +#include #ifdef FBCODE_CAFFE2 #include @@ -82,7 +83,7 @@ TORCH_API inline bool borrowsOutputs(c10::Symbol kind) { class ValueGroup { public: explicit ValueGroup() = default; - void init(const std::shared_ptr& graph, AliasDb& db); + void init(const Block& block, const AliasDb& db); bool isExternalAlias(const Value* value) const { return external_aliases_.find(value) != external_aliases_.end(); @@ -112,7 +113,8 @@ class TORCH_API ManagedTensorRanges { public: ManagedTensorRanges() = default; ManagedTensorRanges( - const std::shared_ptr& graph, + Block& block, + const AliasDb& alias_db, const FastSet& managed_tensor_values); // If true, then this node is the last use of at least one @@ -213,11 +215,122 @@ struct TORCH_API StaticModuleOptions { /// pool.push(runtime); /// @endcode /// - class MemoryPlanner; class ProcessedFunction; class ProcessedNode; class StaticRuntime; + +// A `BlockInfo` instance stores all of the shared state that each +// `BlockRunner` will need to access. Most of this information is +// read-only and shared between threads. +// - Each `BlockInfo` corresponds to one block in the graph. +// - Each `BlockInfo` may be used by multiple block runners (when there are many +// threads). +// - All of the `BlockInfo`s are stored in a vector in the `StaticModule` and +// are initialized during `StaticModule` construction. +// - Most of the information stored is used to initialize the block's memory +// planner. 
+class BlockInfo { + public: + BlockInfo(uint32_t input_idx, Block& block) + : input_idx_(input_idx), block_(block) {} + + void set_nodes( + std::vector nodes, + const FastMap& node_has_out_variant); + + const std::vector& nodes() const { + return nodes_; + } + + size_t num_nodes() const { + return nodes_.size(); + } + + size_t num_inputs() const { + return block_.inputs().size(); + } + + size_t num_outputs() const { + return block_.outputs().size(); + } + + graph_node_list node_ptrs() const { + return block_.nodes(); + } + + void set_output_indices(std::vector indices) { + output_indices_ = std::move(indices); + } + + const std::vector& block_output_indices() const { + return output_indices_; + } + + auto block_inputs_idx() const { + return input_idx_; + } + + bool node_is_optimizable_container_type(const Node* node) const { + return node_is_optimizable_container_type_.find(node) != + node_is_optimizable_container_type_.end(); + } + + bool value_is_managed_tensor(const Value* value) const { + return managed_tensor_values_.find(value) != managed_tensor_values_.end(); + } + + bool value_is_leaked_container(const Value* value) const { + return leaked_values_.find(value) != leaked_values_.end(); + } + + const ValueGroup& value_group() const { + return value_group_; + } + + const ManagedTensorRanges& managed_tensor_ranges() const { + return managed_tensor_ranges_; + } + + void init_value_group(const AliasDb& alias_db) { + value_group_.init(block_, alias_db); + } + + void prepare_for_memory_planner( + const AliasDb& alias_db, + const StaticModuleOptions& opt); + + const auto& managed_output_tensor_values() const { + return managed_output_tensor_values_; + } + + const auto& managed_tensor_values() const { + return managed_tensor_values_; + } + + const auto& leaked_values() const { + return leaked_values_; + } + + private: + std::vector nodes_; + + ValueGroup value_group_; + + FastSet node_is_optimizable_container_type_; + FastSet managed_tensor_values_; + FastSet managed_output_tensor_values_; + FastSet leaked_values_; + + ManagedTensorRanges managed_tensor_ranges_{}; + + // The index of this block's inputs in the shared values_ array. + const uint16_t input_idx_; + // The indices of this block's outputs in the shared values_ array. 
+ std::vector output_indices_; + Block& block_; +}; + class TORCH_API StaticModule { public: explicit StaticModule( @@ -231,23 +344,12 @@ class TORCH_API StaticModule { const StaticModuleOptions& opts = StaticModuleOptions(), std::vector sample_inputs = {}); - typedef enum { - CONSTANT_VALUE = -2, // VALUE nodes defined by prim::Constant - INPUT_VALUE = -1, // VALUE nodes representing graph inputs - } VALUE_KIND; - private: explicit StaticModule( std::pair, c10::optional> graph_and_module, const StaticModuleOptions& opts); - // for - // if kind == CONSTANT_VALUE: map to constants_[idx] - // if kind == INPUT_VALUE: map to inputs_[idx] - // otherwise: map to nodes_[kind].outputs()[idx] - using DefInfo = std::pair; - public: using KeywordArgs = std::unordered_map; c10::IValue operator()( @@ -268,10 +370,6 @@ class TORCH_API StaticModule { const StaticModuleOptions& opts() const; - const ValueGroup& valueGroup() const { - return value_group_; - } - size_t num_inputs() const; size_t num_outputs() const; @@ -295,74 +393,69 @@ class TORCH_API StaticModule { return constants_; } + const BlockInfo& block_info(Block* block) const { + return block_infos_.at(block); + } + + Block* root_block() const { + return graph_->block(); + } + private: friend class StaticRuntime; - - // Our nodes don't have their inputs & outputs initialized; don't - // let anybody but StaticRuntime and tests get them. - const std::vector& nodes() const { - return nodes_; - } + friend class BlockRunner; public: auto num_nodes() const { - return nodes_.size(); + return std::accumulate( + block_infos_.begin(), + block_infos_.end(), + 0, + [](size_t sum, const auto& block_and_info) { + auto& block_info = block_and_info.second; + return sum + block_info.num_nodes(); + }); } C10_NODISCARD Node* findNodeWithKindForTesting(const std::string& kind) const; - graph_node_list node_ptrs() const { - return graph_->nodes(); - } - - bool is_optimizable_container_type(const Node* n) const { - auto it = node_is_optimizable_container_type_.find(n); - return it != node_is_optimizable_container_type_.end(); - } - const c10::optional& schema() const { return schema_; } - const ValueGroup& value_group() const { - return value_group_; - } - - const FastSet& managed_tensor_values() const { - return managed_tensor_values_; - } - - const FastSet& managed_output_tensor_values() const { - return managed_output_tensor_values_; - } - - const FastSet& leaked_values() const { - return leaked_values_; - } - - const ManagedTensorRanges& managed_tensor_ranges() const { - return managed_tensor_ranges_; - } - bool first_input_is_self() const { return module_.has_value(); } - size_t inputs_offset() const { - return 0; - } - - size_t constants_offset() const { - return inputs_offset() + num_inputs(); - } - - size_t intermediate_values_offset() const { - return constants_offset() + num_constants(); - } - StaticRuntime& runtime(); + // See [Shared values array] + size_t value_buffer_size() const { + return value_buffer_size_; + } + private: + // Recursively prepares the BlockInfo array. + // - Populates `value_to_index` with the indices of each intermediate value + // - Returns the number of Value* processed, including sub-blocks. 
+ size_t prepareBlockInfo( + Block* block, + const size_t start_idx, + FastMap& value_to_index); + + void prepareFunctionsAndConstants( + Block* block, + const AliasDb& alias_db, + FastMap& value_to_index); + + // Recurses on sub-blocks and populates the array of ProcessedNodes + // Returns (number of nodes processed, number of blocks processed) + size_t prepareProcessedNodes( + Block* block, + const FastMap& value_to_index, + const AliasDb& alias_db, + size_t node_idx = 0); + // Initialize various attributes that the memory planner will need. // To be called at the tail of the ctor. void prepareForMemoryPlanner(); @@ -383,15 +476,6 @@ class TORCH_API StaticModule { // Indices of graph outputs in the single values array. std::vector output_indices_; - ValueGroup value_group_; - - FastSet node_is_optimizable_container_type_; - - FastSet managed_tensor_values_{}; - FastSet managed_output_tensor_values_{}; - FastSet leaked_values_{}; - ManagedTensorRanges managed_tensor_ranges_{}; - size_t num_intermediate_values_ = 0; // Includes self if module_ != nullopt. @@ -399,16 +483,33 @@ class TORCH_API StaticModule { // argument. In this case, `self` isn't used in the graph, but the schema // includes it anyways to be consistent with the JIT interpreter. size_t num_inputs_; + // See `BlockInfo` definition. The blocks are stored in depth-first order. + FastMap block_infos_; + size_t value_buffer_size_ = 0; }; -class TORCH_API StaticRuntime { +// `BlockRunner` contains the core runtime logic. Each block runner +// corresponds to one block in the graph and has its own memory planner. +// `StaticRuntime` will initialize all `BlockRunner`s +// upon construction. Each block runner only directly executes nodes from its +// block. Special ops with sub-blocks like `prim::If` may have +// `BlockRunner`s stored in their `ProcessedNode`s; these +// sub-blocks get executed in the op's implementation. +// `StaticRuntime` stores a vector of IValues that all +// `BlockRunner`s share. This vector is used to store all +// constants, inputs, and intermediate tensors. +class TORCH_API BlockRunner { public: - explicit StaticRuntime(const StaticModule& sm); - StaticRuntime(StaticRuntime&&) = delete; - StaticRuntime& operator=(StaticRuntime&&) = delete; - ~StaticRuntime(); + BlockRunner( + const StaticModule& sm, + std::vector& values, + Block* block, + bool is_root_block = false); + BlockRunner(BlockRunner&&) noexcept; + BlockRunner& operator=(BlockRunner&&) = delete; + ~BlockRunner(); - C10_DISABLE_COPY_AND_ASSIGN(StaticRuntime); + C10_DISABLE_COPY_AND_ASSIGN(BlockRunner); using KeywordArgs = std::unordered_map; c10::IValue operator()( @@ -451,11 +552,16 @@ class TORCH_API StaticRuntime { // Input is readwrite IValue& Input(uint32_t i) { - DCHECK_LT(i, static_module_.num_inputs()); + DCHECK_LT(i, block_info_.num_inputs()); DCHECK_LT(i, values_.size()); - return values_[i]; + return values_[i + block_info_.block_inputs_idx()]; } + size_t init_sub_blocks( + const StaticModule& sm, + std::vector& values, + size_t block_idx); + // Output is readonly. 
The writing process happens inside ProcessedNodes C10_NODISCARD const IValue& Output(uint32_t i) const { DCHECK(i < outputs_.size()); @@ -475,7 +581,7 @@ class TORCH_API StaticRuntime { } graph_node_list node_ptrs() const { - return static_module_.node_ptrs(); + return block_info_.node_ptrs(); } const Graph& graph() const { @@ -486,11 +592,9 @@ class TORCH_API StaticRuntime { return planner_.get(); } - bool check_for_memory_leak(bool output_returned = true); - - bool is_optimizable_container_type(Node* n) const { - return static_module_.is_optimizable_container_type(n); - } + bool check_for_memory_leak( + bool output_returned = true, + bool recurse_on_sub_blocks = false); // WARNING: Deallocate managed output tensors. A client receiving Static // Runtime-managed Tensors needs to be very careful to call @@ -521,7 +625,8 @@ class TORCH_API StaticRuntime { // when destructed. class Deallocator { public: - explicit Deallocator(StaticRuntime& runtime) : runtime_(runtime) {} + explicit Deallocator(BlockRunner& block_runner) + : block_runner_(block_runner) {} Deallocator(Deallocator&&) = default; Deallocator(const Deallocator&) = default; @@ -537,7 +642,7 @@ class TORCH_API StaticRuntime { void cleanupImpl(); bool finished_ = false; - StaticRuntime& runtime_; + BlockRunner& block_runner_; }; template @@ -569,8 +674,8 @@ class TORCH_API StaticRuntime { // clean up owning refs of input IValues void clean_up_input_ivalues() noexcept { - for (const auto idx : c10::irange(static_module_.num_inputs())) { - values_[idx] = IValue(); + for (const auto idx : c10::irange(block_info_.num_inputs())) { + values_[idx + inputs_begin_] = IValue(); } } @@ -591,16 +696,29 @@ class TORCH_API StaticRuntime { const KeywordArgs& kwargs); const StaticModule& static_module_; + const BlockInfo& block_info_; + + const bool is_root_block_; // Cache this so we don't have to call static_module_.first_input_is_self() const bool first_input_is_self_; + // Index of the start of this blocks inputs in the shared values_ array. + const uint16_t inputs_begin_; + bool manage_output_tensors_enabled_ = false; std::unique_ptr planner_; - // first static_module_.num_inputs() slots are inputs, next - // static_module_.constants().size() slots are a copy of - // static_module_.constants(), rest are regular values in the - // graph. ProcessedNodes reference their inputs and outputs with + // [Shared values array] + // ProcessedNodes reference their inputs and outputs with // offsets into this array, which saves memory. - std::vector values_; + // All BlockRunners share the same array. The layout is as + // follows: + // [constants][block_0][block_1]...[block_N] + // Note that constants from all blocks are pooled together at the start. + // The block ordering is depth-first. + // Each block is further divided into inputs and intermediates: + // [block_i] = [inputs_i][intermediates_i] + // Each BlockRunner knows where its inputs start. Each ProcessedNode + // knows how to find the indices of its outputs/inputs in this array. + std::vector& values_; std::vector outputs_; std::vector nodes_; }; @@ -643,15 +761,55 @@ class TORCH_API ProcessedNode { ProcessedNode() = default; // ProcessedNodes are created within StaticModule and then // associated with a shared values array using set_values() when - // they are copied into a StaticRuntime. + // they are copied into a StaticRuntime. block_runners_ are also + // not initialized until StaticRuntime initialization; see + // BlockRunner's ctor. 
ProcessedNode( Node* n, ProcessedFunction* fn, ProcessedNodeInputs inputs, uint16_t outputs_offset); - ProcessedNode(const ProcessedNode&) = default; - ProcessedNode& operator=(const ProcessedNode&) = default; + ProcessedNode(const ProcessedNode& other) + : node_(other.node_), + fn_(other.fn_), + overlap_detected_(other.overlap_detected_), + inputs_(other.inputs_), + outputs_offset_(other.outputs_offset_), + num_outputs_(other.num_outputs_), + values_(other.values_), + // It doesn't really make sense to copy block runners, + // each processed node needs its own. This is OK to do + // since ProcessedNodes are copied from StaticModule right before + // the block runners are set up. + // TODO(T105178680): For this task, we should move + // block runners out of ProcessedNode. Then, we don't have to deal + // with this caveat. + block_runners_(nullptr) +#ifndef PYTORCH_DISABLE_PER_OP_PROFILING + , + op_name_(other.op_name_) +#endif + { + } + + ProcessedNode& operator=(const ProcessedNode& other) { + if (&other == this) { + return *this; + } + node_ = other.node_; + fn_ = other.fn_; + overlap_detected_ = other.overlap_detected_; + inputs_ = other.inputs_; + outputs_offset_ = other.outputs_offset_; + num_outputs_ = other.num_outputs_; + values_ = other.values_; + block_runners_ = nullptr; +#ifndef PYTORCH_DISABLE_PER_OP_PROFILING + op_name_ = other.op_name_; +#endif + return *this; + } // These should be noexcept, but some Android build is failing // saying the noexcept specification doesn't match the calculated @@ -732,11 +890,21 @@ class TORCH_API ProcessedNode { } C10_NODISCARD uint16_t output_ivalue_index(uint16_t i) const { + DCHECK(i < num_outputs_); return outputs_offset_ + i; } // used in debug mode bool verify_no_memory_overlap(bool force_check = false) const; + std::vector* block_runners() { + return block_runners_.get(); + } + + void set_block_runners( + std::unique_ptr> block_runners) { + block_runners_ = std::move(block_runners); + } + private: C10_NODISCARD bool verify_outputs_dont_overlap_each_other() const; @@ -749,10 +917,74 @@ class TORCH_API ProcessedNode { uint16_t outputs_offset_; uint16_t num_outputs_; IValue* values_ = nullptr; // unowned + // For control flow; processed nodes may have sub-blocks which can + // be executed by op implementations. + std::unique_ptr> block_runners_; #ifndef PYTORCH_DISABLE_PER_OP_PROFILING const char* op_name_; #endif }; +// `StaticRuntime` is the owner of the array of IValues (used for constants, +// inputs, and intermediate tensors) that all `BlockRunner`s share. +// Upon construction, it initializes all block runners. `operator()` simply +// forwards the inputs to the top-level block runner. Each `StaticRuntime` +// instance corresponds to one `StaticModule`. Multiple `StaticRuntime` +// instances can be created; this is useful for multi-threaded execution, since +// `operator()` is not thread-safe. +class TORCH_API StaticRuntime { + public: + explicit StaticRuntime(const StaticModule& sm); + + using KeywordArgs = std::unordered_map; + c10::IValue operator()( + const std::vector& args, + const KeywordArgs& kwargs = KeywordArgs()); + c10::IValue operator()( + std::vector&& args, + const KeywordArgs& kwargs = KeywordArgs()); + + bool check_for_memory_leak(bool output_returned = true); + bool checkOutputTensorMemoryLeaks(); + + void deallocateOutputTensors(); + bool isManagedOutputTensor(const IValue& ivalue) const; + void disableManageOutputTensors(); + + // Gets the top-level memory planner. Used for testing. 
+ const MemoryPlanner* get_memory_planner() const; + + void benchmark( + const std::vector>& args_list, + const std::vector& kwargs_list, + const int warmup_runs, + const int main_runs, + bool print_per_node_time = false, + bool generate_ai_pep_output = false) { + block_->benchmark( + args_list, + kwargs_list, + warmup_runs, + main_runs, + print_per_node_time, + generate_ai_pep_output); + } + + using IndividualMetrics = BlockRunner::IndividualMetrics; + + IndividualMetrics benchmark_individual_ops( + const std::vector>& args_list, + const std::vector& kwargs_list, + const int warmup_runs, + const int main_runs) { + return block_->benchmark_individual_ops( + args_list, kwargs_list, warmup_runs, main_runs); + } + + private: + std::unique_ptr block_; + std::vector values_; +}; + } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/runtime/static/memory_planner.cpp b/torch/csrc/jit/runtime/static/memory_planner.cpp index c44a70f6b49..b19bfe5f6f9 100644 --- a/torch/csrc/jit/runtime/static/memory_planner.cpp +++ b/torch/csrc/jit/runtime/static/memory_planner.cpp @@ -129,10 +129,10 @@ bool setIncludes(const FastSet& set, const Value* v) { } std::vector> assignStorageToOutputTensors( - StaticRuntime* runtime, + BlockRunner* block_runner, const FastSet& managed_output_tensor_values) { std::vector> managed_output_tensors; - for (auto& pnode : runtime->nodes()) { + for (auto& pnode : block_runner->nodes()) { for (const auto i : c10::irange(pnode.outputs().size())) { auto& ival = pnode.Output(i); const auto* val = pnode.node()->outputs()[i]; @@ -151,19 +151,20 @@ std::vector> assignStorageToOutputTensors( } // namespace MemoryPlanner::MemoryPlanner( - StaticRuntime* runtime, - const ValueGroup& value_group, - const FastSet& managed_tensor_values, - const FastSet& managed_output_tensor_values, - const FastSet& leaked_values, - const ManagedTensorRanges& ranges, + BlockRunner* block_runner, + const BlockInfo& block_info, bool enable_out_variant, bool manage_output_tensors, bool optimize_memory) { + const auto& managed_tensor_values = block_info.managed_tensor_values(); + const auto& managed_output_tensor_values = + block_info.managed_output_tensor_values(); + const auto& leaked_values = block_info.leaked_values(); + // collect unmanaged output ivalues FastSet unmanaged_ivalues; FastSet unmanaged_borrowed_ivalues; - for (ProcessedNode& pnode : runtime->nodes()) { + for (ProcessedNode& pnode : block_runner->nodes()) { const auto borrows_outputs = borrowsOutputs(pnode.node()->kind()); for (const auto i : c10::irange(pnode.outputs().size())) { const Value* out_v = pnode.node()->outputs()[i]; @@ -189,7 +190,7 @@ MemoryPlanner::MemoryPlanner( } } } - for (IValue* output : runtime->outputs()) { + for (IValue* output : block_runner->outputs()) { auto it = unmanaged_borrowed_ivalues.find(output); if (it != unmanaged_borrowed_ivalues.end()) { borrowed_ivalues_needing_incref_.push_back(output); @@ -213,10 +214,12 @@ MemoryPlanner::MemoryPlanner( if (enable_out_variant) { const auto tensor_value_to_tensor = - tensorValueToTensor(runtime->nodes(), managed_tensor_values); + tensorValueToTensor(block_runner->nodes(), managed_tensor_values); if (optimize_memory) { managed_tensors_ = assignStorageToManagedTensors( - runtime->node_ptrs(), ranges, tensor_value_to_tensor); + block_info.node_ptrs(), + block_info.managed_tensor_ranges(), + tensor_value_to_tensor); } else { for (auto& tensor : tensor_value_to_tensor) { managed_tensors_.emplace_back(tensor.second); @@ -225,8 +228,8 @@ 
MemoryPlanner::MemoryPlanner( } if (enable_out_variant && manage_output_tensors) { - managed_output_tensors_ = - assignStorageToOutputTensors(runtime, managed_output_tensor_values); + managed_output_tensors_ = assignStorageToOutputTensors( + block_runner, managed_output_tensor_values); } num_managed_tensors_ = 0; diff --git a/torch/csrc/jit/runtime/static/memory_planner.h b/torch/csrc/jit/runtime/static/memory_planner.h index d4159eb2d1c..5281de1e6e5 100644 --- a/torch/csrc/jit/runtime/static/memory_planner.h +++ b/torch/csrc/jit/runtime/static/memory_planner.h @@ -93,12 +93,8 @@ TORCH_API std::vector assignStorageToManagedTensors( class MemoryPlanner { public: explicit MemoryPlanner( - StaticRuntime* runtime, - const ValueGroup& value_group, - const FastSet& managed_tensor_values, - const FastSet& managed_output_tensor_values, - const FastSet& leaked_values, - const ManagedTensorRanges& ranges, + BlockRunner* block_runner, + const BlockInfo& block_info, bool enable_out_variant, bool manage_output_tensors, bool optimize_memory);
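A minimal toy sketch (C++, not part of the patch above) of the ownership pattern the summary describes: `StaticRuntime` owns the single flat values array, while every `BlockRunner`, including the runners created for nested sub-blocks, keeps only an unowned reference to that array plus the offset at which its own inputs start, which is what lets a sub-block look up values defined in an enclosing scope. `ToyRuntime`, `ToyBlockRunner`, and the plain-int `Value` below are hypothetical, simplified stand-ins and do not use the actual torch types or signatures.

#include <cstddef>
#include <iostream>
#include <vector>

using Value = int;  // stand-in for c10::IValue

// Holds an unowned reference to the runtime's buffer, as BlockRunner does.
class ToyBlockRunner {
 public:
  ToyBlockRunner(std::vector<Value>& values, std::size_t inputs_begin)
      : values_(values), inputs_begin_(inputs_begin) {}

  // The root runner and a sub-block runner index into the same buffer,
  // so a sub-block can read values produced by its surrounding scope.
  Value& input(std::size_t i) {
    return values_[inputs_begin_ + i];
  }

 private:
  std::vector<Value>& values_;  // unowned; owned by ToyRuntime
  std::size_t inputs_begin_;    // where this block's slots start in the shared array
};

// Owns the flat values buffer, as StaticRuntime does.
class ToyRuntime {
 public:
  explicit ToyRuntime(std::size_t buffer_size)
      : values_(buffer_size),
        root_(values_, /*inputs_begin=*/0),
        sub_(values_, /*inputs_begin=*/2) {}

  void run() {
    root_.input(0) = 41;                 // written by the outer scope
    sub_.input(0) = root_.input(0) + 1;  // sub-block reads the shared buffer
    std::cout << sub_.input(0) << "\n";  // prints 42
  }

 private:
  std::vector<Value> values_;  // the single owner of all values
  ToyBlockRunner root_;
  ToyBlockRunner sub_;
};

int main() {
  ToyRuntime rt(/*buffer_size=*/4);
  rt.run();
  return 0;
}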