Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-07 12:21:27 +01:00)
[SR] Add BlockRunner and handle sub-blocks (#69834)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/69834

* Modify the `StaticModule` constructor to handle index initialization for sub-blocks.
* Add a new class `StaticRuntimeBlockRunner`. This class is almost exactly like what we've been calling `StaticRuntime` up to this point, except that it does not own a `values_` array. All `StaticRuntimeBlockRunner`s hold an unowned reference to a `values_` array owned by `StaticRuntime`. This is a useful abstraction for implementing control flow - it gives us a way for sub-blocks to look up values from surrounding scopes!

ghstack-source-id: 148086245

Test Plan: `buck test caffe2/benchmarks/static_runtime/...`

Reviewed By: d1jang

Differential Revision: D33028039

fbshipit-source-id: 4f01417bad51a0cf09b1680a518308da647be1f6
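The key design point in the summary is the ownership split: StaticRuntime owns the single `values_` array, and every block runner, including the ones created for sub-blocks, holds an unowned reference to it, which is what lets a sub-block read values produced by its surrounding scope. Below is a minimal, self-contained C++ sketch of that arrangement, assuming simplified stand-in classes (`Runtime`, `BlockRunner`) and an integer value table; it illustrates the idea only and is not the actual Static Runtime API.

#include <cstddef>
#include <iostream>
#include <vector>

// Stand-in for an IValue slot in the shared values table (assumption: int).
using Slot = int;

// A block runner borrows the values table owned by the runtime, so a
// sub-block runner can read values written by an enclosing scope.
class BlockRunner {
 public:
  BlockRunner(std::vector<Slot>& values, std::size_t input_idx)
      : values_(values), input_idx_(input_idx) {}

  // Reads a value produced in a surrounding scope and writes its result
  // into the next slot of the shared table.
  void run() {
    values_[input_idx_ + 1] = values_[input_idx_] * 2;
  }

 private:
  std::vector<Slot>& values_;  // unowned reference; the runtime owns the array
  std::size_t input_idx_;
};

// The runtime owns the one values array; all block runners borrow it.
class Runtime {
 public:
  Runtime() : values_(4, 0), root_(values_, 0), sub_(values_, 1) {}

  Slot operator()(Slot input) {
    values_[0] = input;
    root_.run();  // writes values_[1]
    sub_.run();   // the "sub-block" sees values_[1] from the outer scope
    return values_[2];
  }

 private:
  std::vector<Slot> values_;
  BlockRunner root_;
  BlockRunner sub_;
};

int main() {
  Runtime runtime;
  std::cout << runtime(3) << '\n';  // prints 12
}

The only point of the sketch is the lifetime relationship: the borrowed reference stays valid because the runtime constructs and destroys the runners, mirroring the commit's description of `StaticRuntimeBlockRunner`.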
This commit is contained in:
parent bdcc5f5f47
commit 3a9feffd92
@ -106,7 +106,8 @@ TEST(StaticModule, ValueGroup) {
|
||||||
torch::jit::StaticModule sm(input_graph);
|
torch::jit::StaticModule sm(input_graph);
|
||||||
const Graph& graph = sm.graph();
|
const Graph& graph = sm.graph();
|
||||||
std::vector<const Node*> nodes(graph.nodes().begin(), graph.nodes().end());
|
std::vector<const Node*> nodes(graph.nodes().begin(), graph.nodes().end());
|
||||||
const auto& value_group = sm.value_group();
|
auto* root_block = sm.root_block();
|
||||||
|
const auto& value_group = sm.block_info(root_block).value_group();
|
||||||
|
|
||||||
std::vector<const Value*> expected_input_aliases{
|
std::vector<const Value*> expected_input_aliases{
|
||||||
graph.inputs()[0], graph.inputs()[1], nodes[0]->output()};
|
graph.inputs()[0], graph.inputs()[1], nodes[0]->output()};
|
||||||
|
|
@ -138,9 +139,11 @@ TEST(StaticModule, IsOptimizableContainerType_NonOptimizableInputs) {
|
||||||
|
|
||||||
auto sm = makeStaticModuleFromScript(src);
|
auto sm = makeStaticModuleFromScript(src);
|
||||||
const auto& graph = sm.graph();
|
const auto& graph = sm.graph();
|
||||||
|
auto* root_block = sm.root_block();
|
||||||
|
const auto& block_info = sm.block_info(root_block);
|
||||||
|
|
||||||
for (const Node* n : graph.nodes()) {
|
for (const Node* n : graph.nodes()) {
|
||||||
EXPECT_FALSE(sm.is_optimizable_container_type(n));
|
EXPECT_FALSE(block_info.node_is_optimizable_container_type(n));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -158,9 +161,11 @@ TEST(StaticModule, IsOptimizableContainerType_WrongType) {
|
||||||
|
|
||||||
auto sm = makeStaticModuleFromScript(src);
|
auto sm = makeStaticModuleFromScript(src);
|
||||||
const auto& graph = sm.graph();
|
const auto& graph = sm.graph();
|
||||||
|
auto* root_block = sm.root_block();
|
||||||
|
const auto& block_info = sm.block_info(root_block);
|
||||||
|
|
||||||
for (const Node* n : graph.nodes()) {
|
for (const Node* n : graph.nodes()) {
|
||||||
EXPECT_FALSE(sm.is_optimizable_container_type(n));
|
EXPECT_FALSE(block_info.node_is_optimizable_container_type(n));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -175,12 +180,14 @@ TEST(StaticModule, IsOptimizableContainerType_CanUseOutVariant) {
|
||||||
)JIT";
|
)JIT";
|
||||||
auto sm = makeStaticModuleFromScript(src);
|
auto sm = makeStaticModuleFromScript(src);
|
||||||
const auto& graph = sm.graph();
|
const auto& graph = sm.graph();
|
||||||
|
auto* root_block = sm.root_block();
|
||||||
|
const auto& block_info = sm.block_info(root_block);
|
||||||
|
|
||||||
for (const Node* n : graph.nodes()) {
|
for (const Node* n : graph.nodes()) {
|
||||||
if (n->kind() == c10::prim::ListConstruct) {
|
if (n->kind() == c10::prim::ListConstruct) {
|
||||||
EXPECT_TRUE(sm.is_optimizable_container_type(n));
|
EXPECT_TRUE(block_info.node_is_optimizable_container_type(n));
|
||||||
} else {
|
} else {
|
||||||
EXPECT_FALSE(sm.is_optimizable_container_type(n));
|
EXPECT_FALSE(block_info.node_is_optimizable_container_type(n));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -1050,7 +1057,8 @@ TEST(ManagedTensorRanges, NoAliases) {
|
||||||
auto* z = vmap["z"];
|
auto* z = vmap["z"];
|
||||||
|
|
||||||
FastSet<const Value*> managed_tensors = {y, z};
|
FastSet<const Value*> managed_tensors = {y, z};
|
||||||
ManagedTensorRanges ranges(graph, managed_tensors);
|
AliasDb alias_db(graph);
|
||||||
|
auto ranges = ManagedTensorRanges(*graph->block(), alias_db, managed_tensors);
|
||||||
|
|
||||||
std::vector<Node*> nodes(
|
std::vector<Node*> nodes(
|
||||||
graph->block()->nodes().begin(), graph->block()->nodes().end());
|
graph->block()->nodes().begin(), graph->block()->nodes().end());
|
||||||
|
|
@ -1089,7 +1097,8 @@ TEST(ManagedTensorRanges, AliasExtendingLifetimes) {
|
||||||
auto* z2 = vmap["z2"];
|
auto* z2 = vmap["z2"];
|
||||||
|
|
||||||
FastSet<const Value*> managed_tensors = {y, z1, z2};
|
FastSet<const Value*> managed_tensors = {y, z1, z2};
|
||||||
ManagedTensorRanges ranges(graph, managed_tensors);
|
AliasDb alias_db(graph);
|
||||||
|
auto ranges = ManagedTensorRanges(*graph->block(), alias_db, managed_tensors);
|
||||||
|
|
||||||
std::vector<Node*> nodes(
|
std::vector<Node*> nodes(
|
||||||
graph->block()->nodes().begin(), graph->block()->nodes().end());
|
graph->block()->nodes().begin(), graph->block()->nodes().end());
|
||||||
|
|
@ -1135,7 +1144,8 @@ TEST(ManagedTensorRanges, LifetimeOverlap) {
|
||||||
auto* d = vmap["d"];
|
auto* d = vmap["d"];
|
||||||
auto* e = vmap["e"];
|
auto* e = vmap["e"];
|
||||||
|
|
||||||
ManagedTensorRanges ranges(graph, {b, c, d, e});
|
AliasDb alias_db(graph);
|
||||||
|
auto ranges = ManagedTensorRanges(*graph->block(), alias_db, {b, c, d, e});
|
||||||
const std::vector<std::pair<Value*, Value*>> overlapping_values{
|
const std::vector<std::pair<Value*, Value*>> overlapping_values{
|
||||||
{b, c}, {c, d}, {c, e}};
|
{b, c}, {c, d}, {c, e}};
|
||||||
|
|
||||||
|
|
@ -1169,7 +1179,8 @@ TEST(ManagedTensorRanges, OverlappingLifetimesContainers) {
|
||||||
auto* c = vmap["c"];
|
auto* c = vmap["c"];
|
||||||
auto* d = vmap["d"];
|
auto* d = vmap["d"];
|
||||||
|
|
||||||
ManagedTensorRanges ranges(graph, {b, c, d});
|
AliasDb alias_db(graph);
|
||||||
|
auto ranges = ManagedTensorRanges(*graph->block(), alias_db, {b, c, d});
|
||||||
|
|
||||||
EXPECT_TRUE(ranges.lifetimesOverlap(b, c));
|
EXPECT_TRUE(ranges.lifetimesOverlap(b, c));
|
||||||
EXPECT_TRUE(ranges.lifetimesOverlap(b, d));
|
EXPECT_TRUE(ranges.lifetimesOverlap(b, d));
|
||||||
|
|
@ -1189,7 +1200,8 @@ TEST(ManagedTensorRanges, OverlappingLifetimesOutputs) {
|
||||||
auto* b = vmap["b"];
|
auto* b = vmap["b"];
|
||||||
auto* output = vmap["output"];
|
auto* output = vmap["output"];
|
||||||
|
|
||||||
ManagedTensorRanges ranges(graph, {b, output});
|
AliasDb alias_db(graph);
|
||||||
|
auto ranges = ManagedTensorRanges(*graph->block(), alias_db, {b, output});
|
||||||
|
|
||||||
EXPECT_TRUE(ranges.lifetimesOverlap(b, output));
|
EXPECT_TRUE(ranges.lifetimesOverlap(b, output));
|
||||||
}
|
}
|
||||||
|
|
@ -1275,7 +1287,9 @@ void testAssignStorageToManagedTensors(
|
||||||
}
|
}
|
||||||
ASSERT_EQ(managed_tensor_values.size(), tensor_value_to_tensor.size());
|
ASSERT_EQ(managed_tensor_values.size(), tensor_value_to_tensor.size());
|
||||||
|
|
||||||
auto ranges = ManagedTensorRanges(graph, managed_tensor_values);
|
AliasDb alias_db(graph);
|
||||||
|
auto ranges =
|
||||||
|
ManagedTensorRanges(*graph->block(), alias_db, managed_tensor_values);
|
||||||
auto groups = assignStorageToManagedTensors(
|
auto groups = assignStorageToManagedTensors(
|
||||||
graph->block()->nodes(), ranges, tensor_value_to_tensor);
|
graph->block()->nodes(), ranges, tensor_value_to_tensor);
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -19,6 +19,7 @@
|
||||||
#include <torch/csrc/jit/passes/remove_mutation.h>
|
#include <torch/csrc/jit/passes/remove_mutation.h>
|
||||||
#include <torch/csrc/jit/passes/subgraph_rewrite.h>
|
#include <torch/csrc/jit/passes/subgraph_rewrite.h>
|
||||||
#include <torch/csrc/jit/passes/variadic_ops.h>
|
#include <torch/csrc/jit/passes/variadic_ops.h>
|
||||||
|
#include <torch/csrc/jit/runtime/graph_iterator.h>
|
||||||
#include <torch/csrc/jit/runtime/static/fusion.h>
|
#include <torch/csrc/jit/runtime/static/fusion.h>
|
||||||
#include <torch/csrc/jit/runtime/static/memory_planner.h>
|
#include <torch/csrc/jit/runtime/static/memory_planner.h>
|
||||||
#include <torch/csrc/jit/runtime/static/ops.h>
|
#include <torch/csrc/jit/runtime/static/ops.h>
|
||||||
|
|
@ -32,6 +33,7 @@
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#include <iterator>
|
#include <iterator>
|
||||||
|
#include <limits>
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
#include <stdexcept>
|
#include <stdexcept>
|
||||||
|
|
||||||
|
|
@ -58,8 +60,8 @@ bool isUnsupportedOp(const NodeKind& kind) {
|
||||||
return kind == aten::__is__ || kind == aten::__isnot__;
|
return kind == aten::__is__ || kind == aten::__isnot__;
|
||||||
}
|
}
|
||||||
|
|
||||||
// graph must be frozen or canEnableStaticRuntime would return false if there's
|
// graph must be frozen or canEnableStaticRuntime would return false
|
||||||
// any prim::CallMethod op left in the graph
|
// if there's any prim::CallMethod op left in the graph
|
||||||
bool canEnableStaticRuntime(const std::shared_ptr<torch::jit::Graph>& graph) {
|
bool canEnableStaticRuntime(const std::shared_ptr<torch::jit::Graph>& graph) {
|
||||||
// check for sub-blocks
|
// check for sub-blocks
|
||||||
bool can_support = true;
|
bool can_support = true;
|
||||||
|
|
@ -181,26 +183,20 @@ std::vector<Value*> valueVecFromFastSet(const FastSet<const Value*>& s) {
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool mayContainAlias(AliasDb& db, const Value* a, const Value* b) {
|
bool mayContainAlias(const AliasDb& db, const Value* v1, const Value* v2) {
|
||||||
|
// AliasDb is not const-correct here, so we have to const_cast
|
||||||
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
|
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
|
||||||
return db.mayContainAlias(const_cast<Value*>(a), const_cast<Value*>(b));
|
return db.mayContainAlias(const_cast<Value*>(v1), const_cast<Value*>(v2));
|
||||||
}
|
}
|
||||||
|
|
||||||
bool mayContainAlias(
|
bool mayContainAlias(
|
||||||
AliasDb& db,
|
const AliasDb& db,
|
||||||
const Value* a,
|
const Value* a,
|
||||||
const FastSet<const Value*>& b) {
|
const FastSet<const Value*>& b) {
|
||||||
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
|
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
|
||||||
return db.mayContainAlias(const_cast<Value*>(a), valueVecFromFastSet(b));
|
return db.mayContainAlias(const_cast<Value*>(a), valueVecFromFastSet(b));
|
||||||
}
|
}
|
||||||
|
|
||||||
bool mayContainAlias(
|
|
||||||
AliasDb& db,
|
|
||||||
const FastSet<const Value*>& a,
|
|
||||||
const FastSet<const Value*>& b) {
|
|
||||||
return db.mayContainAlias(valueVecFromFastSet(a), valueVecFromFastSet(b));
|
|
||||||
}
|
|
||||||
|
|
||||||
void PrepareGraphForStaticModule(
|
void PrepareGraphForStaticModule(
|
||||||
std::shared_ptr<torch::jit::Graph> graph,
|
std::shared_ptr<torch::jit::Graph> graph,
|
||||||
const StaticModuleOptions& opts,
|
const StaticModuleOptions& opts,
|
||||||
|
|
@ -248,23 +244,21 @@ std::pair<std::shared_ptr<Graph>, c10::optional<Module>> PrepareForStaticModule(
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
void ValueGroup::init(
|
void ValueGroup::init(const Block& block, const AliasDb& db) {
|
||||||
const std::shared_ptr<torch::jit::Graph>& graph,
|
|
||||||
AliasDb& db) {
|
|
||||||
external_aliases_.clear();
|
external_aliases_.clear();
|
||||||
output_aliases_.clear();
|
output_aliases_.clear();
|
||||||
// Build `external_aliases` as we look through nodes forwardly from
|
// Build `external_aliases` as we look through nodes forwardly from
|
||||||
// the graph's inputs and add aliases of the inputs being created by the
|
// the graph's inputs and add aliases of the inputs being created by the
|
||||||
// nodes.
|
// nodes.
|
||||||
external_aliases_.insert(graph->inputs().begin(), graph->inputs().end());
|
external_aliases_.insert(block.inputs().begin(), block.inputs().end());
|
||||||
for (const auto* node : graph->nodes()) {
|
for (const auto* node : block.nodes()) {
|
||||||
if (node->kind() == prim::Constant) {
|
if (node->kind() == prim::Constant) {
|
||||||
for (const auto* output : node->outputs()) {
|
for (const auto* output : node->outputs()) {
|
||||||
external_aliases_.insert(output);
|
external_aliases_.insert(output);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
for (const auto* node : graph->nodes()) {
|
for (const auto* node : block.nodes()) {
|
||||||
if (node->kind() == prim::Constant) {
|
if (node->kind() == prim::Constant) {
|
||||||
// Constants are already in `external_aliases`.
|
// Constants are already in `external_aliases`.
|
||||||
continue;
|
continue;
|
||||||
|
|
@ -278,8 +272,8 @@ void ValueGroup::init(
|
||||||
|
|
||||||
// Build `output_aliases` as we look through nodes reversely so that we can
|
// Build `output_aliases` as we look through nodes reversely so that we can
|
||||||
// start from the output values, and follow the flows backwardly from there.
|
// start from the output values, and follow the flows backwardly from there.
|
||||||
output_aliases_.insert(graph->outputs().begin(), graph->outputs().end());
|
output_aliases_.insert(block.outputs().begin(), block.outputs().end());
|
||||||
for (const auto* node : graph->nodes().reverse()) {
|
for (const auto* node : block.nodes().reverse()) {
|
||||||
if (node->kind() == prim::Constant) {
|
if (node->kind() == prim::Constant) {
|
||||||
// Constants cannot create any aliases.
|
// Constants cannot create any aliases.
|
||||||
continue;
|
continue;
|
||||||
|
|
@ -317,12 +311,6 @@ bool containTensorsOnly(at::ArrayRef<Value*> values) {
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
bool mayContainAlias(const Value* v1, const Value* v2, const AliasDb& db) {
|
|
||||||
// AliasDb is not const-correct here, so we have to const_cast
|
|
||||||
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
|
|
||||||
return db.mayContainAlias(const_cast<Value*>(v1), const_cast<Value*>(v2));
|
|
||||||
}
|
|
||||||
|
|
||||||
bool isPureFunction(const Node* node) {
|
bool isPureFunction(const Node* node) {
|
||||||
auto* schema = node->maybeSchema();
|
auto* schema = node->maybeSchema();
|
||||||
return schema &&
|
return schema &&
|
||||||
|
|
@ -332,12 +320,12 @@ bool isPureFunction(const Node* node) {
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
ManagedTensorRanges::ManagedTensorRanges(
|
ManagedTensorRanges::ManagedTensorRanges(
|
||||||
const std::shared_ptr<Graph>& graph,
|
Block& block,
|
||||||
|
const AliasDb& alias_db,
|
||||||
const FastSet<const Value*>& managed_tensor_values) {
|
const FastSet<const Value*>& managed_tensor_values) {
|
||||||
AliasDb alias_db(graph);
|
const std::vector<Node*> nodes(block.nodes().begin(), block.nodes().end());
|
||||||
const std::vector<Node*> nodes(graph->nodes().begin(), graph->nodes().end());
|
|
||||||
const FastSet<const Value*> graph_inputs(
|
const FastSet<const Value*> graph_inputs(
|
||||||
graph->inputs().begin(), graph->inputs().end());
|
block.inputs().begin(), block.inputs().end());
|
||||||
|
|
||||||
auto isUntrackedValue = [&alias_db, &graph_inputs](const Value* value) {
|
auto isUntrackedValue = [&alias_db, &graph_inputs](const Value* value) {
|
||||||
return !alias_db.isMutableType(value) ||
|
return !alias_db.isMutableType(value) ||
|
||||||
|
|
@ -363,7 +351,7 @@ ManagedTensorRanges::ManagedTensorRanges(
|
||||||
value_lifetimes_.emplace(output, Lifetime(i, i));
|
value_lifetimes_.emplace(output, Lifetime(i, i));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
for (auto* graph_output : graph->outputs()) {
|
for (auto* graph_output : block.outputs()) {
|
||||||
auto* lifetime = getLifetime(graph_output);
|
auto* lifetime = getLifetime(graph_output);
|
||||||
if (!lifetime) {
|
if (!lifetime) {
|
||||||
DCHECK(isUntrackedValue(graph_output));
|
DCHECK(isUntrackedValue(graph_output));
|
||||||
|
|
@ -376,7 +364,7 @@ ManagedTensorRanges::ManagedTensorRanges(
|
||||||
// has an input and output that may alias each other, set the input's
|
// has an input and output that may alias each other, set the input's
|
||||||
// lifetime end to max(input.lifetime_end, output.lifetime_end). Iterate
|
// lifetime end to max(input.lifetime_end, output.lifetime_end). Iterate
|
||||||
// backwards to handle chains of aliases.
|
// backwards to handle chains of aliases.
|
||||||
for (const auto* node : graph->nodes().reverse()) {
|
for (const auto* node : block.nodes().reverse()) {
|
||||||
if (isPureFunction(node)) {
|
if (isPureFunction(node)) {
|
||||||
// If the node is a pure function, it doesn't create any aliases,
|
// If the node is a pure function, it doesn't create any aliases,
|
||||||
// so we can safely skip it.
|
// so we can safely skip it.
|
||||||
|
|
@ -389,7 +377,7 @@ ManagedTensorRanges::ManagedTensorRanges(
|
||||||
auto* input_lifetime = getLifetime(input);
|
auto* input_lifetime = getLifetime(input);
|
||||||
DCHECK(input_lifetime != nullptr);
|
DCHECK(input_lifetime != nullptr);
|
||||||
for (auto* output : outputs) {
|
for (auto* output : outputs) {
|
||||||
if (mayContainAlias(input, output, alias_db)) {
|
if (mayContainAlias(alias_db, input, output)) {
|
||||||
auto* output_lifetime = getLifetime(output);
|
auto* output_lifetime = getLifetime(output);
|
||||||
DCHECK(output_lifetime != nullptr);
|
DCHECK(output_lifetime != nullptr);
|
||||||
input_lifetime->end =
|
input_lifetime->end =
|
||||||
|
|
@ -404,7 +392,7 @@ ManagedTensorRanges::ManagedTensorRanges(
|
||||||
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
|
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
|
||||||
Node* freeing_node;
|
Node* freeing_node;
|
||||||
if (lifetime->end == num_nodes) {
|
if (lifetime->end == num_nodes) {
|
||||||
freeing_node = graph->return_node();
|
freeing_node = block.return_node();
|
||||||
} else {
|
} else {
|
||||||
freeing_node = nodes[lifetime->end];
|
freeing_node = nodes[lifetime->end];
|
||||||
}
|
}
|
||||||
|
|
@ -519,15 +507,6 @@ StaticModule::StaticModule(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// map Value* to its SSA definition IR
|
|
||||||
FastMap<Value*, DefInfo> value_to_ssa_def;
|
|
||||||
|
|
||||||
// N inputs map to the first N entries in storage
|
|
||||||
for (const auto i : c10::irange(graph_->inputs().size())) {
|
|
||||||
Value* input = graph_->inputs()[i];
|
|
||||||
value_to_ssa_def[input] = std::make_pair(INPUT_VALUE, i);
|
|
||||||
}
|
|
||||||
|
|
||||||
{
|
{
|
||||||
size_t nodes_size = 0, constants_size = 0;
|
size_t nodes_size = 0, constants_size = 0;
|
||||||
for (Node* node : graph_->nodes()) {
|
for (Node* node : graph_->nodes()) {
|
||||||
|
|
@ -536,7 +515,6 @@ StaticModule::StaticModule(
|
||||||
|
|
||||||
constants_.reserve(constants_size);
|
constants_.reserve(constants_size);
|
||||||
functions_.reserve(nodes_size);
|
functions_.reserve(nodes_size);
|
||||||
nodes_.reserve(nodes_size);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create ProcessedFunction instances first to freeze their addresses to pass
|
// Create ProcessedFunction instances first to freeze their addresses to pass
|
||||||
|
|
@ -544,13 +522,89 @@ StaticModule::StaticModule(
|
||||||
AliasDb alias_db(graph_, /*isFrozen=*/false);
|
AliasDb alias_db(graph_, /*isFrozen=*/false);
|
||||||
GRAPH_DEBUG("AliasDb: ", alias_db.toString());
|
GRAPH_DEBUG("AliasDb: ", alias_db.toString());
|
||||||
|
|
||||||
// Construct constant and function nodes
|
// Maps each Value* in the graph to its index in the values_ array that will
|
||||||
for (Node* node : graph_->nodes()) {
|
// eventually be created by StaticRuntime.
|
||||||
|
FastMap<const Value*, uint32_t> value_to_index;
|
||||||
|
prepareFunctionsAndConstants(graph_->block(), alias_db, value_to_index);
|
||||||
|
|
||||||
|
const auto constants_index_offset = 0;
|
||||||
|
const auto values_index_offset = constants_index_offset + constants().size();
|
||||||
|
value_buffer_size_ = values_index_offset;
|
||||||
|
|
||||||
|
value_buffer_size_ +=
|
||||||
|
prepareBlockInfo(graph_->block(), values_index_offset, value_to_index);
|
||||||
|
|
||||||
|
prepareProcessedNodes(graph_->block(), value_to_index, alias_db);
|
||||||
|
|
||||||
|
for (auto& block_and_info : block_infos_) {
|
||||||
|
auto& block_info = block_and_info.second;
|
||||||
|
block_info.prepare_for_memory_planner(alias_db, opts);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t StaticModule::prepareBlockInfo(
|
||||||
|
Block* block,
|
||||||
|
const size_t start_idx,
|
||||||
|
FastMap<const Value*, uint32_t>& value_to_index) {
|
||||||
|
block_infos_.emplace(block, BlockInfo(start_idx, *block));
|
||||||
|
|
||||||
|
const auto num_inputs = block->inputs().size();
|
||||||
|
for (const auto i : c10::irange(num_inputs)) {
|
||||||
|
value_to_index.emplace(block->inputs()[i], start_idx + i);
|
||||||
|
}
|
||||||
|
auto cur_idx = start_idx + num_inputs;
|
||||||
|
|
||||||
|
for (auto* node : block->nodes()) {
|
||||||
|
for (auto* sub_block : node->blocks()) {
|
||||||
|
cur_idx += prepareBlockInfo(sub_block, cur_idx, value_to_index);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (node->kind() == prim::Constant) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
TORCH_CHECK(
|
||||||
|
cur_idx < (1 << 16),
|
||||||
|
"outputs offset in values table",
|
||||||
|
cur_idx,
|
||||||
|
" would overflow 2-byte index storage");
|
||||||
|
|
||||||
|
const auto num_outputs = node->outputs().size();
|
||||||
|
for (const auto i : c10::irange(num_outputs)) {
|
||||||
|
value_to_index.emplace(node->outputs()[i], cur_idx + i);
|
||||||
|
}
|
||||||
|
cur_idx += num_outputs;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<uint16_t> output_indices;
|
||||||
|
output_indices.reserve(block->outputs().size());
|
||||||
|
for (auto* output : block->outputs()) {
|
||||||
|
const auto output_idx = value_to_index.at(output);
|
||||||
|
TORCH_CHECK(
|
||||||
|
output_idx < (1 << 16),
|
||||||
|
"outputs offset in values table",
|
||||||
|
output_idx,
|
||||||
|
" would overflow 2-byte index storage");
|
||||||
|
output_indices.push_back(output_idx);
|
||||||
|
}
|
||||||
|
|
||||||
|
block_infos_.at(block).set_output_indices(std::move(output_indices));
|
||||||
|
return cur_idx - start_idx;
|
||||||
|
}
|
||||||
|
|
||||||
|
void StaticModule::prepareFunctionsAndConstants(
|
||||||
|
Block* block,
|
||||||
|
const AliasDb& alias_db,
|
||||||
|
FastMap<const Value*, uint32_t>& value_to_index) {
|
||||||
|
for (auto* node : block->nodes()) {
|
||||||
|
for (auto* sub_block : node->blocks()) {
|
||||||
|
prepareFunctionsAndConstants(sub_block, alias_db, value_to_index);
|
||||||
|
}
|
||||||
|
|
||||||
if (node->kind() == prim::Constant) {
|
if (node->kind() == prim::Constant) {
|
||||||
auto* v = node->output();
|
auto* v = node->output();
|
||||||
TORCH_CHECK(v->type()->kind() != FunctionType::Kind);
|
TORCH_CHECK(v->type()->kind() != FunctionType::Kind);
|
||||||
// construct SSA definition for constant nodes
|
value_to_index.emplace(v, constants_.size());
|
||||||
value_to_ssa_def[v] = std::make_pair(CONSTANT_VALUE, constants_.size());
|
|
||||||
constants_.emplace_back(toIValue(v).value());
|
constants_.emplace_back(toIValue(v).value());
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
@ -561,66 +615,34 @@ StaticModule::StaticModule(
|
||||||
containTensorsOnly(node->outputs());
|
containTensorsOnly(node->outputs());
|
||||||
// new ProcessedFunction
|
// new ProcessedFunction
|
||||||
functions_.emplace_back(
|
functions_.emplace_back(
|
||||||
node, opts.enable_out_variant, check_outputs_for_overlap);
|
node, opts_.enable_out_variant, check_outputs_for_overlap);
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// construct SSA definition for non-constant nodes
|
size_t StaticModule::prepareProcessedNodes(
|
||||||
int node_idx = 0;
|
Block* block,
|
||||||
|
const FastMap<const Value*, uint32_t>& value_to_index,
|
||||||
|
const AliasDb& alias_db,
|
||||||
|
size_t node_idx) {
|
||||||
|
const auto node_start = node_idx;
|
||||||
|
|
||||||
|
auto& block_info = block_infos_.at(block);
|
||||||
|
std::vector<ProcessedNode> nodes;
|
||||||
FastMap<Node*, bool> node_has_out_variant;
|
FastMap<Node*, bool> node_has_out_variant;
|
||||||
|
|
||||||
const auto inputs_index_offset = inputs_offset();
|
for (auto* node : block->nodes()) {
|
||||||
const auto constants_index_offset = constants_offset();
|
|
||||||
const auto values_index_offset = intermediate_values_offset();
|
|
||||||
|
|
||||||
// Map node_idx to index offset in values_. Can't reserve space
|
|
||||||
// because we don't know how many non-constant nodes there are yet.
|
|
||||||
std::vector<uint32_t> node_output_idx_map;
|
|
||||||
uint32_t node_outputs_seen_so_far = 0;
|
|
||||||
for (Node* node : graph_->nodes()) {
|
|
||||||
if (node->kind() == prim::Constant) {
|
if (node->kind() == prim::Constant) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
// Assign memory for the outputs
|
|
||||||
const auto outputs_offset_for_node =
|
|
||||||
node_outputs_seen_so_far + values_index_offset;
|
|
||||||
TORCH_CHECK(
|
|
||||||
outputs_offset_for_node < (1 << 16),
|
|
||||||
"outputs offset in values table",
|
|
||||||
outputs_offset_for_node,
|
|
||||||
" would overflow 2-byte index storage");
|
|
||||||
node_output_idx_map.push_back(outputs_offset_for_node);
|
|
||||||
node_outputs_seen_so_far += node->outputs().size();
|
|
||||||
}
|
|
||||||
|
|
||||||
for (Node* node : graph_->nodes()) {
|
for (auto* sub_block : node->blocks()) {
|
||||||
if (node->kind() == prim::Constant) {
|
node_idx +=
|
||||||
continue;
|
prepareProcessedNodes(sub_block, value_to_index, alias_db, node_idx);
|
||||||
}
|
}
|
||||||
ProcessedNodeInputs input_indices(node->inputs().size());
|
ProcessedNodeInputs input_indices(node->inputs().size());
|
||||||
std::vector<DefInfo> input_ssa_defs;
|
|
||||||
for (const auto input_idx : c10::irange(node->inputs().size())) {
|
for (const auto input_idx : c10::irange(node->inputs().size())) {
|
||||||
Value* const input = node->inputs()[input_idx];
|
auto* input = node->inputs()[input_idx];
|
||||||
int inner_node_idx = 0;
|
auto input_ivalue_idx = value_to_index.at(input);
|
||||||
int out_idx = 0;
|
|
||||||
std::tie(inner_node_idx, out_idx) = value_to_ssa_def.at(input);
|
|
||||||
unsigned int input_ivalue_idx = 0;
|
|
||||||
if (inner_node_idx == StaticModule::INPUT_VALUE) {
|
|
||||||
input_ivalue_idx = out_idx + inputs_index_offset;
|
|
||||||
} else if (inner_node_idx == StaticModule::CONSTANT_VALUE) {
|
|
||||||
input_ivalue_idx = out_idx + constants_index_offset;
|
|
||||||
} else {
|
|
||||||
DCHECK_GE(inner_node_idx, 0);
|
|
||||||
const auto global_value_idx =
|
|
||||||
node_output_idx_map[inner_node_idx] + out_idx;
|
|
||||||
if (inner_node_idx < node_output_idx_map.size() - 1) {
|
|
||||||
DCHECK_LT(global_value_idx, node_output_idx_map[inner_node_idx + 1]);
|
|
||||||
} else {
|
|
||||||
DCHECK_LT(
|
|
||||||
global_value_idx,
|
|
||||||
constants_index_offset + node_outputs_seen_so_far);
|
|
||||||
}
|
|
||||||
input_ivalue_idx = global_value_idx;
|
|
||||||
}
|
|
||||||
TORCH_CHECK(
|
TORCH_CHECK(
|
||||||
input_ivalue_idx < (1 << 16),
|
input_ivalue_idx < (1 << 16),
|
||||||
"input index in values table ",
|
"input index in values table ",
|
||||||
|
|
@ -630,72 +652,48 @@ StaticModule::StaticModule(
|
||||||
}
|
}
|
||||||
|
|
||||||
ProcessedFunction* fn = &functions_[node_idx];
|
ProcessedFunction* fn = &functions_[node_idx];
|
||||||
|
|
||||||
// create a new ProcessedNode
|
// create a new ProcessedNode
|
||||||
// see [Check and correct bad schema alias info at runtime]
|
const auto node_output_idx = node->outputs().empty()
|
||||||
bool check_outputs_for_overlap =
|
// The index is unused if there are no outputs, so just create a
|
||||||
!alias_db.mayContainAlias(node->inputs(), node->outputs()) &&
|
// placeholder value.
|
||||||
containTensorsOnly(node->outputs());
|
? std::numeric_limits<uint16_t>::max()
|
||||||
nodes_.emplace_back(
|
: value_to_index.at(node->output(0));
|
||||||
node, fn, std::move(input_indices), node_output_idx_map[node_idx]);
|
nodes.emplace_back(node, fn, std::move(input_indices), node_output_idx);
|
||||||
|
|
||||||
node_has_out_variant.emplace(node, nodes_.back().has_out_variant());
|
node_has_out_variant.emplace(node, nodes.back().has_out_variant());
|
||||||
for (const auto i : c10::irange(node->outputs().size())) {
|
++node_idx;
|
||||||
value_to_ssa_def[node->outputs()[i]] = std::make_pair(node_idx, i);
|
|
||||||
}
|
|
||||||
node_idx++;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
num_intermediate_values_ = std::accumulate(
|
block_info.set_nodes(std::move(nodes), node_has_out_variant);
|
||||||
nodes_.begin(),
|
block_info.init_value_group(alias_db);
|
||||||
nodes_.end(),
|
|
||||||
0,
|
|
||||||
[](uint32_t sum, const ProcessedNode& pnode) {
|
|
||||||
return sum + pnode.num_outputs();
|
|
||||||
});
|
|
||||||
|
|
||||||
for (auto& pnode : nodes_) {
|
return node_idx - node_start;
|
||||||
if (pnode.num_outputs() == 1 &&
|
|
||||||
isOptimizableContainerType(pnode.node(), node_has_out_variant)) {
|
|
||||||
node_is_optimizable_container_type_.emplace(pnode.node());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
output_indices_.reserve(graph_->outputs().size());
|
|
||||||
for (auto output : graph_->outputs()) {
|
|
||||||
int node_idx = 0;
|
|
||||||
int out_idx = 0;
|
|
||||||
std::tie(node_idx, out_idx) = value_to_ssa_def[output];
|
|
||||||
uint32_t output_index = 0;
|
|
||||||
if (node_idx == StaticModule::INPUT_VALUE) {
|
|
||||||
output_index = out_idx + inputs_index_offset;
|
|
||||||
} else if (node_idx == StaticModule::CONSTANT_VALUE) {
|
|
||||||
output_index = constants_index_offset + out_idx;
|
|
||||||
} else {
|
|
||||||
output_index = nodes_[node_idx].output_ivalue_index(out_idx);
|
|
||||||
}
|
|
||||||
TORCH_CHECK(
|
|
||||||
output_index < (1 << 16),
|
|
||||||
"output index ",
|
|
||||||
output_index,
|
|
||||||
" would overflow 2-byte index storage");
|
|
||||||
output_indices_.emplace_back(output_index);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Prepare for memory planning
|
|
||||||
value_group_.init(graph_, alias_db);
|
|
||||||
GRAPH_DEBUG(value_group_.toString());
|
|
||||||
|
|
||||||
prepareForMemoryPlanner();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void StaticModule::prepareForMemoryPlanner() {
|
void BlockInfo::set_nodes(
|
||||||
if (!opts_.enable_out_variant) {
|
std::vector<ProcessedNode> nodes,
|
||||||
|
const FastMap<Node*, bool>& node_has_out_variant) {
|
||||||
|
nodes_ = std::move(nodes);
|
||||||
|
|
||||||
|
for (auto& node : nodes_) {
|
||||||
|
if (node.num_outputs() == 1 &&
|
||||||
|
isOptimizableContainerType(node.node(), node_has_out_variant)) {
|
||||||
|
node_is_optimizable_container_type_.emplace(node.node());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
void BlockInfo::prepare_for_memory_planner(
|
||||||
|
const AliasDb& alias_db,
|
||||||
|
const StaticModuleOptions& opts) {
|
||||||
|
if (!opts.enable_out_variant) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Never manage graph outputs so that we can do std::move(output_ivalue).
|
// Never manage graph outputs so that we can do std::move(output_ivalue).
|
||||||
// This does not affect performance if the graph returns a collection object.
|
// This does not affect performance if the graph returns a collection object.
|
||||||
FastSet<const Value*> graph_output_values(
|
FastSet<const Value*> graph_output_values(
|
||||||
graph_->outputs().begin(), graph_->outputs().end());
|
block_.outputs().begin(), block_.outputs().end());
|
||||||
|
|
||||||
// collect register indices of outputs of ops with out variant
|
// collect register indices of outputs of ops with out variant
|
||||||
for (ProcessedNode& pnode : nodes_) {
|
for (ProcessedNode& pnode : nodes_) {
|
||||||
|
|
@ -707,7 +705,7 @@ void StaticModule::prepareForMemoryPlanner() {
|
||||||
const Value* out_v = outputs[i];
|
const Value* out_v = outputs[i];
|
||||||
// Types are stored in the underlying TorchScript IR
|
// Types are stored in the underlying TorchScript IR
|
||||||
bool is_tensor_type = out_v->type()->castRaw<TensorType>();
|
bool is_tensor_type = out_v->type()->castRaw<TensorType>();
|
||||||
if (opts_.manage_output_tensors && is_tensor_type &&
|
if (opts.manage_output_tensors && is_tensor_type &&
|
||||||
graph_output_values.find(out_v) == graph_output_values.end() &&
|
graph_output_values.find(out_v) == graph_output_values.end() &&
|
||||||
value_group_.isOutputAlias(out_v)) {
|
value_group_.isOutputAlias(out_v)) {
|
||||||
managed_output_tensor_values_.insert(out_v);
|
managed_output_tensor_values_.insert(out_v);
|
||||||
|
|
@ -718,7 +716,7 @@ void StaticModule::prepareForMemoryPlanner() {
|
||||||
}
|
}
|
||||||
if (is_tensor_type) {
|
if (is_tensor_type) {
|
||||||
managed_tensor_values_.insert(out_v);
|
managed_tensor_values_.insert(out_v);
|
||||||
} else if (is_optimizable_container_type(pnode.node())) {
|
} else if (node_is_optimizable_container_type(pnode.node())) {
|
||||||
// We "leak" certain container types because their allocations
|
// We "leak" certain container types because their allocations
|
||||||
// take a long time
|
// take a long time
|
||||||
leaked_values_.insert(out_v);
|
leaked_values_.insert(out_v);
|
||||||
|
|
@ -726,7 +724,7 @@ void StaticModule::prepareForMemoryPlanner() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for (const Value* output : graph_->outputs()) {
|
for (const Value* output : block_.outputs()) {
|
||||||
managed_tensor_values_.erase(output);
|
managed_tensor_values_.erase(output);
|
||||||
}
|
}
|
||||||
GRAPH_DEBUG("managed_tensor_values: ", dumpValueSet(managed_tensor_values_));
|
GRAPH_DEBUG("managed_tensor_values: ", dumpValueSet(managed_tensor_values_));
|
||||||
|
|
@ -734,7 +732,8 @@ void StaticModule::prepareForMemoryPlanner() {
|
||||||
"managed_output_tensor_values_: ",
|
"managed_output_tensor_values_: ",
|
||||||
dumpValueSet(managed_output_tensor_values_));
|
dumpValueSet(managed_output_tensor_values_));
|
||||||
|
|
||||||
managed_tensor_ranges_ = ManagedTensorRanges(graph_, managed_tensor_values_);
|
managed_tensor_ranges_ =
|
||||||
|
ManagedTensorRanges(block_, alias_db, managed_tensor_values_);
|
||||||
}
|
}
|
||||||
|
|
||||||
const StaticModuleOptions& StaticModule::opts() const {
|
const StaticModuleOptions& StaticModule::opts() const {
|
||||||
|
|
@ -757,9 +756,12 @@ StaticRuntime& StaticModule::runtime() {
|
||||||
}
|
}
|
||||||
|
|
||||||
Node* StaticModule::findNodeWithKindForTesting(const std::string& kind) const {
|
Node* StaticModule::findNodeWithKindForTesting(const std::string& kind) const {
|
||||||
for (auto& pnode : nodes()) {
|
for (auto& block_and_info : block_infos_) {
|
||||||
if (pnode.node()->kind().toQualString() == kind) {
|
auto& block_info = block_and_info.second;
|
||||||
return pnode.node();
|
for (auto& pnode : block_info.nodes()) {
|
||||||
|
if (pnode.node()->kind().toQualString() == kind) {
|
||||||
|
return pnode.node();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
|
@ -777,41 +779,64 @@ c10::IValue StaticModule::operator()(
|
||||||
return runtime()(std::move(args), kwargs);
|
return runtime()(std::move(args), kwargs);
|
||||||
}
|
}
|
||||||
|
|
||||||
StaticRuntime::StaticRuntime(const StaticModule& sm)
|
BlockRunner::BlockRunner(
|
||||||
|
const StaticModule& sm,
|
||||||
|
std::vector<IValue>& values,
|
||||||
|
Block* block,
|
||||||
|
bool is_root_block)
|
||||||
: static_module_(sm),
|
: static_module_(sm),
|
||||||
first_input_is_self_(static_module_.first_input_is_self()),
|
block_info_(static_module_.block_info(block)),
|
||||||
manage_output_tensors_enabled_(sm.opts().manage_output_tensors),
|
is_root_block_(is_root_block),
|
||||||
nodes_(sm.nodes()) {
|
first_input_is_self_(
|
||||||
values_.resize(sm.total_num_values());
|
is_root_block_ && static_module_.first_input_is_self()),
|
||||||
const auto constants_index_offset = sm.constants_offset();
|
inputs_begin_(block_info_.block_inputs_idx()),
|
||||||
const auto constants_begin_it = values_.begin() + constants_index_offset;
|
// TODO(T108633124): Turn on manage output tensors for sub-blocks.
|
||||||
const auto constants_end_it = constants_begin_it + sm.constants().size();
|
manage_output_tensors_enabled_(
|
||||||
std::copy(sm.constants().begin(), sm.constants().end(), constants_begin_it);
|
is_root_block_ && sm.opts().manage_output_tensors),
|
||||||
|
values_(values),
|
||||||
for (const auto idx : c10::irange(sm.nodes().size())) {
|
nodes_(block_info_.nodes()) {
|
||||||
auto& n = nodes_[idx];
|
for (auto& n : nodes_) {
|
||||||
n.set_values(values_.data());
|
n.set_values(values_.data());
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: can we convert outputs_ to store indices?
|
for (auto index : block_info_.block_output_indices()) {
|
||||||
for (auto index : sm.output_indices()) {
|
|
||||||
outputs_.emplace_back(&values_[index]);
|
outputs_.emplace_back(&values_[index]);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
for (auto& pnode : nodes_) {
|
||||||
|
auto* node = pnode.node();
|
||||||
|
auto blocks = node->blocks();
|
||||||
|
const auto num_blocks = blocks.size();
|
||||||
|
if (num_blocks == 0) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
DCHECK(node->kind() == prim::If || node->kind() == prim::Loop);
|
||||||
|
auto block_runners = std::make_unique<std::vector<BlockRunner>>();
|
||||||
|
block_runners->reserve(num_blocks);
|
||||||
|
|
||||||
|
for (auto* b : blocks) {
|
||||||
|
block_runners->emplace_back(sm, values, b);
|
||||||
|
}
|
||||||
|
pnode.set_block_runners(std::move(block_runners));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
StaticRuntime::~StaticRuntime() = default;
|
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
|
||||||
|
BlockRunner::BlockRunner(BlockRunner&&) noexcept = default;
|
||||||
|
|
||||||
void StaticRuntime::set_arg(const size_t idx, std::vector<IValue>&& args) {
|
BlockRunner::~BlockRunner() = default;
|
||||||
|
|
||||||
|
void BlockRunner::set_arg(const size_t idx, std::vector<IValue>&& args) {
|
||||||
DCHECK(idx < args.size());
|
DCHECK(idx < args.size());
|
||||||
Input(idx + first_input_is_self_) = std::move(args[idx]);
|
Input(idx + first_input_is_self_) = std::move(args[idx]);
|
||||||
}
|
}
|
||||||
|
|
||||||
void StaticRuntime::set_arg(const size_t idx, const std::vector<IValue>& args) {
|
void BlockRunner::set_arg(const size_t idx, const std::vector<IValue>& args) {
|
||||||
DCHECK(idx < args.size());
|
DCHECK(idx < args.size());
|
||||||
Input(idx + first_input_is_self_) = args[idx];
|
Input(idx + first_input_is_self_) = args[idx];
|
||||||
}
|
}
|
||||||
|
|
||||||
void StaticRuntime::set_arg(const size_t idx, const IValue& arg) {
|
void BlockRunner::set_arg(const size_t idx, const IValue& arg) {
|
||||||
Input(idx + first_input_is_self_) = arg;
|
Input(idx + first_input_is_self_) = arg;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -827,24 +852,21 @@ void check_type(const Argument& schema_arg, const IValue& arg) {
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
template <typename IValueList>
|
template <typename IValueList>
|
||||||
void StaticRuntime::set_inputs(
|
void BlockRunner::set_inputs(
|
||||||
IValueList&& args,
|
IValueList&& args,
|
||||||
const std::unordered_map<std::string, c10::IValue>& kwargs) {
|
const std::unordered_map<std::string, c10::IValue>& kwargs) {
|
||||||
const auto total_num_inputs =
|
const auto total_num_inputs =
|
||||||
args.size() + kwargs.size() + first_input_is_self_;
|
args.size() + kwargs.size() + first_input_is_self_;
|
||||||
TORCH_CHECK(total_num_inputs == static_module_.num_inputs());
|
TORCH_CHECK(total_num_inputs == block_info_.num_inputs());
|
||||||
|
|
||||||
const auto& schema = static_module_.schema();
|
const auto& schema = static_module_.schema();
|
||||||
if (first_input_is_self_) {
|
if (first_input_is_self_) {
|
||||||
Input(0) = static_module_.module()._ivalue();
|
Input(0) = static_module_.module()._ivalue();
|
||||||
}
|
}
|
||||||
|
|
||||||
if (C10_UNLIKELY(!schema)) {
|
if (!is_root_block_ || C10_UNLIKELY(!schema)) {
|
||||||
TORCH_CHECK(
|
TORCH_CHECK(
|
||||||
kwargs.empty(),
|
kwargs.empty(), "Schema is not available, but BlockRunner got kwargs.");
|
||||||
"Schema is not available, but StaticRuntime got kwargs. "
|
|
||||||
"Consider creating the Static Runtime instance "
|
|
||||||
"with StaticModule(const torch::jit::Module& m) instead.");
|
|
||||||
for (size_t i = 0; i < args.size(); ++i) {
|
for (size_t i = 0; i < args.size(); ++i) {
|
||||||
set_arg(i, std::forward<IValueList>(args));
|
set_arg(i, std::forward<IValueList>(args));
|
||||||
}
|
}
|
||||||
|
|
@ -887,15 +909,11 @@ void StaticRuntime::set_inputs(
|
||||||
args.size() + consumed_kwargs == schema_args.size() - 1);
|
args.size() + consumed_kwargs == schema_args.size() - 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
void StaticRuntime::create_memory_planner() {
|
void BlockRunner::create_memory_planner() {
|
||||||
if (!planner_) {
|
if (!planner_) {
|
||||||
planner_ = std::make_unique<MemoryPlanner>(
|
planner_ = std::make_unique<MemoryPlanner>(
|
||||||
this,
|
this,
|
||||||
static_module_.value_group(),
|
block_info_,
|
||||||
static_module_.managed_tensor_values(),
|
|
||||||
static_module_.managed_output_tensor_values(),
|
|
||||||
static_module_.leaked_values(),
|
|
||||||
static_module_.managed_tensor_ranges(),
|
|
||||||
static_module_.opts().enable_out_variant,
|
static_module_.opts().enable_out_variant,
|
||||||
manage_output_tensors_enabled_,
|
manage_output_tensors_enabled_,
|
||||||
static_module_.opts().optimize_memory);
|
static_module_.opts().optimize_memory);
|
||||||
|
|
@ -924,7 +942,7 @@ void destroyNodeOutputs(ProcessedNode& p_node) {
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
void StaticRuntime::clean_up_intermediate_ivalues() noexcept {
|
void BlockRunner::clean_up_intermediate_ivalues() noexcept {
|
||||||
// We have to iterate in reverse order here due to borrowed
|
// We have to iterate in reverse order here due to borrowed
|
||||||
// IValues - we don't want to destroy a value until all of its
|
// IValues - we don't want to destroy a value until all of its
|
||||||
// borrows are cleaned up!
|
// borrows are cleaned up!
|
||||||
|
|
@ -933,7 +951,7 @@ void StaticRuntime::clean_up_intermediate_ivalues() noexcept {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void StaticRuntime::resetMemory() noexcept {
|
void BlockRunner::resetMemory() noexcept {
|
||||||
planner_.reset();
|
planner_.reset();
|
||||||
// We must clean up intermediate values before inputs in case
|
// We must clean up intermediate values before inputs in case
|
||||||
// there are borrowed inputs and static runtime owns the only
|
// there are borrowed inputs and static runtime owns the only
|
||||||
|
|
@ -942,7 +960,7 @@ void StaticRuntime::resetMemory() noexcept {
|
||||||
clean_up_input_ivalues();
|
clean_up_input_ivalues();
|
||||||
}
|
}
|
||||||
|
|
||||||
c10::IValue StaticRuntime::move_outputs_to_tuple(uint32_t num_outputs) {
|
c10::IValue BlockRunner::move_outputs_to_tuple(uint32_t num_outputs) {
|
||||||
switch (num_outputs) {
|
switch (num_outputs) {
|
||||||
case 1:
|
case 1:
|
||||||
return c10::ivalue::Tuple::create(IValue(std::move(*outputs_[0])));
|
return c10::ivalue::Tuple::create(IValue(std::move(*outputs_[0])));
|
||||||
|
|
@ -1032,7 +1050,7 @@ c10::IValue StaticRuntime::move_outputs_to_tuple(uint32_t num_outputs) {
|
||||||
/// buffer) fails. There is still a corner case that fails with the added flag.
|
/// buffer) fails. There is still a corner case that fails with the added flag.
|
||||||
/// If a resize is triggered at the same time as the op creating an alias at the
|
/// If a resize is triggered at the same time as the op creating an alias at the
|
||||||
/// same time, the current checks would fail to detect the alias.
|
/// same time, the current checks would fail to detect the alias.
|
||||||
void StaticRuntime::verify_and_correct_memory_overlap(ProcessedNode& n) {
|
void BlockRunner::verify_and_correct_memory_overlap(ProcessedNode& n) {
|
||||||
// The slow check can be removed once the internal/output buffers are merged
|
// The slow check can be removed once the internal/output buffers are merged
|
||||||
if (C10_UNLIKELY(n.check_outputs_for_memory_overlap())) {
|
if (C10_UNLIKELY(n.check_outputs_for_memory_overlap())) {
|
||||||
if (C10_UNLIKELY(!planner_)) {
|
if (C10_UNLIKELY(!planner_)) {
|
||||||
|
|
@ -1065,7 +1083,7 @@ void StaticRuntime::verify_and_correct_memory_overlap(ProcessedNode& n) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
bool StaticRuntime::fast_check_and_correct_overlap_with(
|
bool BlockRunner::fast_check_and_correct_overlap_with(
|
||||||
ProcessedNode& n,
|
ProcessedNode& n,
|
||||||
c10::IValue& tensor_ival) {
|
c10::IValue& tensor_ival) {
|
||||||
auto& tensor = tensor_ival.toTensor();
|
auto& tensor = tensor_ival.toTensor();
|
||||||
|
|
@ -1078,38 +1096,38 @@ bool StaticRuntime::fast_check_and_correct_overlap_with(
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
StaticRuntime::Deallocator::~Deallocator() {
|
BlockRunner::Deallocator::~Deallocator() {
|
||||||
// Assume cleanup cannot throw.
|
// Assume cleanup cannot throw.
|
||||||
cleanupImpl();
|
cleanupImpl();
|
||||||
#ifndef NDEBUG
|
#ifndef NDEBUG
|
||||||
runtime_.check_for_memory_leak(/*output_returned*/ false);
|
block_runner_.check_for_memory_leak(/*output_returned*/ false);
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
void StaticRuntime::Deallocator::cleanupImpl() {
|
void BlockRunner::Deallocator::cleanupImpl() {
|
||||||
// MemoryPlanner is created after the first invocation of `run()`. This
|
// MemoryPlanner is created after the first invocation of `run()`. This
|
||||||
// is done intentionally because MemoryPlanner uses `Tensor` sizes of
|
// is done intentionally because MemoryPlanner uses `Tensor` sizes of
|
||||||
// the previous `run()` for memory planning of subsequent runs
|
// the previous `run()` for memory planning of subsequent runs
|
||||||
if (C10_LIKELY(finished_)) {
|
if (C10_LIKELY(finished_)) {
|
||||||
runtime_.create_memory_planner();
|
block_runner_.create_memory_planner();
|
||||||
}
|
}
|
||||||
|
|
||||||
if (C10_LIKELY(runtime_.planner_)) {
|
if (C10_LIKELY(block_runner_.planner_)) {
|
||||||
runtime_.planner_->deallocate();
|
block_runner_.planner_->deallocate();
|
||||||
} else {
|
} else {
|
||||||
// This is the first run, and it didn't finish, so we can't use a
|
// This is the first run, and it didn't finish, so we can't use a
|
||||||
// `MemoryPlanner` to deallocate stuff. Just reset everything mannually.
|
// `MemoryPlanner` to deallocate stuff. Just reset everything mannually.
|
||||||
runtime_.resetMemory();
|
block_runner_.resetMemory();
|
||||||
}
|
}
|
||||||
// clean up owning refs of input tensors
|
// clean up owning refs of input tensors
|
||||||
runtime_.clean_up_input_ivalues();
|
block_runner_.clean_up_input_ivalues();
|
||||||
if (C10_UNLIKELY(!finished_)) {
|
if (C10_UNLIKELY(!finished_)) {
|
||||||
runtime_.deallocateOutputTensors();
|
block_runner_.deallocateOutputTensors();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename IValueList>
|
template <typename IValueList>
|
||||||
c10::IValue StaticRuntime::run_impl(
|
c10::IValue BlockRunner::run_impl(
|
||||||
IValueList&& args,
|
IValueList&& args,
|
||||||
const KeywordArgs& kwargs) {
|
const KeywordArgs& kwargs) {
|
||||||
// We assume inference workloads, so we do not need
|
// We assume inference workloads, so we do not need
|
||||||
|
|
@ -1138,8 +1156,8 @@ c10::IValue StaticRuntime::run_impl(
|
||||||
}
|
}
|
||||||
|
|
||||||
// no need to keep references of outputs in static runtime anymore
|
// no need to keep references of outputs in static runtime anymore
|
||||||
if (static_module_.num_outputs() > 1) {
|
if (block_info_.num_outputs() > 1) {
|
||||||
return move_outputs_to_tuple(static_module_.num_outputs());
|
return move_outputs_to_tuple(block_info_.num_outputs());
|
||||||
}
|
}
|
||||||
|
|
||||||
DCHECK(check_for_memory_leak(/*output_returned*/ false));
|
DCHECK(check_for_memory_leak(/*output_returned*/ false));
|
||||||
|
|
@ -1149,7 +1167,7 @@ c10::IValue StaticRuntime::run_impl(
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename IValueList>
|
template <typename IValueList>
|
||||||
c10::IValue StaticRuntime::run_impl_record_functions(
|
c10::IValue BlockRunner::run_impl_record_functions(
|
||||||
IValueList&& args,
|
IValueList&& args,
|
||||||
const KeywordArgs& kwargs) {
|
const KeywordArgs& kwargs) {
|
||||||
bool pre_sampled = false;
|
bool pre_sampled = false;
|
||||||
|
|
@ -1168,7 +1186,7 @@ c10::IValue StaticRuntime::run_impl_record_functions(
|
||||||
return run_impl(std::forward<IValueList>(args), kwargs);
|
return run_impl(std::forward<IValueList>(args), kwargs);
|
||||||
}
|
}
|
||||||
|
|
||||||
c10::IValue StaticRuntime::operator()(
|
c10::IValue BlockRunner::operator()(
|
||||||
const std::vector<c10::IValue>& args,
|
const std::vector<c10::IValue>& args,
|
||||||
const KeywordArgs& kwargs) {
|
const KeywordArgs& kwargs) {
|
||||||
#ifdef PYTORCH_DISABLE_NET_PROFILING
|
#ifdef PYTORCH_DISABLE_NET_PROFILING
|
||||||
|
|
@ -1178,7 +1196,7 @@ c10::IValue StaticRuntime::operator()(
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
c10::IValue StaticRuntime::operator()(
|
c10::IValue BlockRunner::operator()(
|
||||||
std::vector<c10::IValue>&& args,
|
std::vector<c10::IValue>&& args,
|
||||||
const KeywordArgs& kwargs) {
|
const KeywordArgs& kwargs) {
|
||||||
#ifdef PYTORCH_DISABLE_NET_PROFILING
|
#ifdef PYTORCH_DISABLE_NET_PROFILING
|
||||||
|
|
@ -1205,7 +1223,7 @@ std::string generate_latency_json(const std::string& label, double millis) {
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
void StaticRuntime::benchmark(
|
void BlockRunner::benchmark(
|
||||||
const std::vector<std::vector<c10::IValue>>& args_list,
|
const std::vector<std::vector<c10::IValue>>& args_list,
|
||||||
const std::vector<KeywordArgs>& kwargs_list,
|
const std::vector<KeywordArgs>& kwargs_list,
|
||||||
const int warmup_runs,
|
const int warmup_runs,
|
||||||
|
|
@ -1267,7 +1285,7 @@ void StaticRuntime::benchmark(
|
||||||
}
|
}
|
||||||
std::cout << std::setw(15) << results.total_time << " ms. in Total"
|
std::cout << std::setw(15) << results.total_time << " ms. in Total"
|
||||||
<< std::endl;
|
<< std::endl;
|
||||||
std::cout << "StaticRuntime setup time: " << results.setup_time << " ms"
|
std::cout << "BlockRunner setup time: " << results.setup_time << " ms"
|
||||||
<< std::endl;
|
<< std::endl;
|
||||||
std::cout << "Memory allocation time: " << results.memory_alloc_time
|
std::cout << "Memory allocation time: " << results.memory_alloc_time
|
||||||
<< " ms\n";
|
<< " ms\n";
|
||||||
|
|
@ -1312,7 +1330,7 @@ void StaticRuntime::benchmark(
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
float StaticRuntime::benchmark_model(
|
float BlockRunner::benchmark_model(
|
||||||
const std::vector<std::vector<c10::IValue>>& args_list,
|
const std::vector<std::vector<c10::IValue>>& args_list,
|
||||||
const std::vector<KeywordArgs>& kwargs_list,
|
const std::vector<KeywordArgs>& kwargs_list,
|
||||||
const int warmup_runs,
|
const int warmup_runs,
|
||||||
|
|
@ -1396,7 +1414,7 @@ void display_pnode_info(const ProcessedNode& pnode) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void StaticRuntime::display_nodes(
|
void BlockRunner::display_nodes(
|
||||||
const std::vector<c10::IValue>& args,
|
const std::vector<c10::IValue>& args,
|
||||||
const KeywordArgs& kwargs) {
|
const KeywordArgs& kwargs) {
|
||||||
c10::InferenceMode mode;
|
c10::InferenceMode mode;
|
||||||
|
|
@ -1415,7 +1433,7 @@ void StaticRuntime::display_nodes(
|
||||||
on_exit.setFinished();
|
on_exit.setFinished();
|
||||||
}
|
}
|
||||||
|
|
||||||
StaticRuntime::IndividualMetrics StaticRuntime::benchmark_individual_ops(
|
BlockRunner::IndividualMetrics BlockRunner::benchmark_individual_ops(
|
||||||
const std::vector<std::vector<c10::IValue>>& args_list,
|
const std::vector<std::vector<c10::IValue>>& args_list,
|
||||||
const std::vector<KeywordArgs>& kwargs_list,
|
const std::vector<KeywordArgs>& kwargs_list,
|
||||||
const int warmup_runs,
|
const int warmup_runs,
|
||||||
|
|
@ -1543,10 +1561,16 @@ StaticRuntime::IndividualMetrics StaticRuntime::benchmark_individual_ops(
|
||||||
return results;
|
return results;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool StaticRuntime::check_for_memory_leak(bool output_returned) {
|
bool BlockRunner::check_for_memory_leak(
|
||||||
|
bool output_returned,
|
||||||
|
bool recurse_on_sub_blocks) {
|
||||||
// check for inputs
|
// check for inputs
|
||||||
for (const auto i : c10::irange(static_module_.num_inputs())) {
|
+    for (const auto i : c10::irange(block_info_.num_inputs())) {
-      TORCH_CHECK(values_[i].isNone(), "Input ", i, " was not cleaned up");
+      TORCH_CHECK(
+          values_[i + block_info_.block_inputs_idx()].isNone(),
+          "Input ",
+          i,
+          " was not cleaned up");
    }
  }
  FastSet<const IValue*> output_ivalues(outputs_.begin(), outputs_.end());
  for (const auto n : c10::irange(nodes_.size())) {
@@ -1561,7 +1585,8 @@ bool StaticRuntime::check_for_memory_leak(bool output_returned) {
          (isManagedOutputTensor(*ival) || isManagedOutputTensorValue(val))) {
        // `ival` contains a managed output tensor that the runtime doesn't
        // reclaim at the end of an iteration, but the client does so
-       // by explicitly calling `StaticRuntime::deallocateOutputTensors`.
+       // by explicitly calling
+       // `BlockRunner::deallocateOutputTensors`.
        continue;
      }
      const std::string error_msg = "Output " + c10::to_string(i) + ", %" +
@@ -1573,7 +1598,8 @@ bool StaticRuntime::check_for_memory_leak(bool output_returned) {
      if (!ival->isNone()) {
        TORCH_CHECK(
            ival->isTensor() ||
-               static_module_.is_optimizable_container_type(pnode.node()) ||
+               block_info_.node_is_optimizable_container_type(
+                   pnode.node()) ||
                doesNotHeapAllocateWhenStoredInIValue(*val->type()),
            error_msg);
        if (ival->isTensor()) {
@@ -1595,12 +1621,20 @@ bool StaticRuntime::check_for_memory_leak(bool output_returned) {
        }
      }
    }
+   auto* block_runners = pnode.block_runners();
+   if (recurse_on_sub_blocks && block_runners) {
+     for (auto& block_runner : *block_runners) {
+       block_runner.check_for_memory_leak(
+           output_returned, recurse_on_sub_blocks);
+     }
+   }
  }
  VLOG(1) << "Finished checking for memory leak";
  return true;
}

-void StaticRuntime::deallocateOutputTensors() {
+void BlockRunner::deallocateOutputTensors() {
  if (!static_module_.opts().manage_output_tensors) {
    TORCH_CHECK(
        !planner_ || planner_->numOutputBufferBytes() == 0,
@@ -1613,7 +1647,7 @@ void StaticRuntime::deallocateOutputTensors() {
  }
}

-bool StaticRuntime::checkOutputTensorMemoryLeaks() {
+bool BlockRunner::checkOutputTensorMemoryLeaks() {
  if (!static_module_.opts().manage_output_tensors || !planner_) {
    return true;
  }
@@ -1639,21 +1673,21 @@ bool StaticRuntime::checkOutputTensorMemoryLeaks() {
  return true;
}

-bool StaticRuntime::isManagedOutputTensor(const IValue& ivalue) const {
+bool BlockRunner::isManagedOutputTensor(const IValue& ivalue) const {
  return planner_ && planner_->isManagedOutputTensor(ivalue);
}

-bool StaticRuntime::isManagedOutputTensorValue(const Value* value) const {
+bool BlockRunner::isManagedOutputTensorValue(const Value* value) const {
  // It's possible that manage_output_tensors_ was disabled after initializing
  // managed_output_tensor_values, so we have to check that flag here.
  if (!planner_ || !manage_output_tensors_enabled_) {
    return false;
  }
-  const auto& managed_outputs = static_module_.managed_output_tensor_values();
+  const auto& managed_outputs = block_info_.managed_output_tensor_values();
  return managed_outputs.find(value) != managed_outputs.end();
}

-void StaticRuntime::disableManageOutputTensors() {
+void BlockRunner::disableManageOutputTensors() {
  if (!manage_output_tensors_enabled_) {
    return;
  }
@@ -1915,5 +1949,50 @@ void ProcessedNode::verify_and_correct_memory_overlap() {
  }
}

+StaticRuntime::StaticRuntime(const StaticModule& sm) {
+  values_.resize(sm.value_buffer_size());
+  std::copy(sm.constants().begin(), sm.constants().end(), values_.begin());
+  block_ = std::make_unique<BlockRunner>(
+      sm, values_, sm.root_block(), /*is_root_block*/ true);
+}
+
+c10::IValue StaticRuntime::operator()(
+    const std::vector<c10::IValue>& args,
+    const KeywordArgs& kwargs) {
+  return (*block_)(args, kwargs);
+}
+
+c10::IValue StaticRuntime::operator()(
+    std::vector<c10::IValue>&& args,
+    const KeywordArgs& kwargs) {
+  return (*block_)(std::move(args), kwargs);
+}
+
+bool StaticRuntime::check_for_memory_leak(bool output_returned) {
+  return block_->check_for_memory_leak(
+      output_returned, /* recurse_on_sub_blocks */ true);
+}
+
+bool StaticRuntime::checkOutputTensorMemoryLeaks() {
+  return block_->checkOutputTensorMemoryLeaks();
+}
+
+void StaticRuntime::deallocateOutputTensors() {
+  block_->deallocateOutputTensors();
+}
+
+bool StaticRuntime::isManagedOutputTensor(const IValue& ivalue) const {
+  return block_->isManagedOutputTensor(ivalue);
+}
+
+void StaticRuntime::disableManageOutputTensors() {
+  block_->disableManageOutputTensors();
+}
+
+const MemoryPlanner* StaticRuntime::get_memory_planner() const {
+  return block_->get_memory_planner();
+}
+
 } // namespace jit
 } // namespace torch
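For illustration only, not part of this patch: a minimal sketch of how a caller might drive the new top-level API added above. `run_once` and the toy IR string are hypothetical; only the StaticModule/StaticRuntime calls come from this change.

#include <ATen/ATen.h>
#include <torch/csrc/jit/ir/irparser.h>
#include <torch/csrc/jit/runtime/static/impl.h>

// Hypothetical caller code; the IR below is a toy add graph.
void run_once() {
  const std::string ir = R"IR(
    graph(%a : Tensor, %b : Tensor):
      %1 : int = prim::Constant[value=1]()
      %c : Tensor = aten::add(%a, %b, %1)
      return (%c)
  )IR";
  auto graph = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(ir, graph.get());

  torch::jit::StaticModule sm(graph);
  torch::jit::StaticRuntime runtime(sm); // owns the shared values_ buffer

  std::vector<c10::IValue> args{at::randn({2, 2}), at::randn({2, 2})};
  const auto out = runtime(args);        // forwarded to the root BlockRunner

  // With this patch, the leak check also recurses into any sub-block runners.
  TORCH_CHECK(runtime.check_for_memory_leak());
  TORCH_CHECK(out.isTensor());
}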
@@ -13,6 +13,7 @@
 #include <torch/csrc/jit/passes/freeze_module.h>
 #include <torch/csrc/jit/passes/inliner.h>
 #include <torch/csrc/jit/runtime/static/ProcessedNodeInputs.h>
+#include <limits>

 #ifdef FBCODE_CAFFE2
 #include <folly/container/F14Map.h>
@@ -82,7 +83,7 @@ TORCH_API inline bool borrowsOutputs(c10::Symbol kind) {
 class ValueGroup {
  public:
   explicit ValueGroup() = default;
-  void init(const std::shared_ptr<torch::jit::Graph>& graph, AliasDb& db);
+  void init(const Block& block, const AliasDb& db);

   bool isExternalAlias(const Value* value) const {
     return external_aliases_.find(value) != external_aliases_.end();
@@ -112,7 +113,8 @@ class TORCH_API ManagedTensorRanges {
  public:
   ManagedTensorRanges() = default;
   ManagedTensorRanges(
-      const std::shared_ptr<Graph>& graph,
+      Block& block,
+      const AliasDb& alias_db,
       const FastSet<const Value*>& managed_tensor_values);

   // If true, then this node is the last use of at least one
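As a sketch outside this patch: the signatures above now operate on a single Block rather than a whole Graph, so each sub-block can be analyzed independently. `analyze_block` is a hypothetical helper and assumes the surrounding declarations are in scope.

// Hypothetical helper demonstrating the per-block analysis entry points.
void analyze_block(
    torch::jit::Block& block,
    torch::jit::AliasDb& alias_db,
    const torch::jit::FastSet<const torch::jit::Value*>& managed_values) {
  torch::jit::ValueGroup group;
  group.init(block, alias_db); // alias classification is now per-block
  torch::jit::ManagedTensorRanges ranges(block, alias_db, managed_values);
  (void)ranges; // the ranges would feed this block's memory planner
}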
@@ -213,11 +215,122 @@ struct TORCH_API StaticModuleOptions {
 /// pool.push(runtime);
 /// @endcode
 ///

 class MemoryPlanner;
 class ProcessedFunction;
 class ProcessedNode;
 class StaticRuntime;

+// A `BlockInfo` instance stores all of the shared state that each
+// `BlockRunner` will need to access. Most of this information is
+// read-only and shared between threads.
+// - Each `BlockInfo` corresponds to one block in the graph.
+// - Each `BlockInfo` may be used by multiple block runners (when there are many
+//   threads).
+// - All of the `BlockInfo`s are stored in a vector in the `StaticModule` and
+//   are initialized during `StaticModule` construction.
+// - Most of the information stored is used to initialize the block's memory
+//   planner.
+class BlockInfo {
+ public:
+  BlockInfo(uint32_t input_idx, Block& block)
+      : input_idx_(input_idx), block_(block) {}
+
+  void set_nodes(
+      std::vector<ProcessedNode> nodes,
+      const FastMap<Node*, bool>& node_has_out_variant);
+
+  const std::vector<ProcessedNode>& nodes() const {
+    return nodes_;
+  }
+
+  size_t num_nodes() const {
+    return nodes_.size();
+  }
+
+  size_t num_inputs() const {
+    return block_.inputs().size();
+  }
+
+  size_t num_outputs() const {
+    return block_.outputs().size();
+  }
+
+  graph_node_list node_ptrs() const {
+    return block_.nodes();
+  }
+
+  void set_output_indices(std::vector<uint16_t> indices) {
+    output_indices_ = std::move(indices);
+  }
+
+  const std::vector<uint16_t>& block_output_indices() const {
+    return output_indices_;
+  }
+
+  auto block_inputs_idx() const {
+    return input_idx_;
+  }
+
+  bool node_is_optimizable_container_type(const Node* node) const {
+    return node_is_optimizable_container_type_.find(node) !=
+        node_is_optimizable_container_type_.end();
+  }
+
+  bool value_is_managed_tensor(const Value* value) const {
+    return managed_tensor_values_.find(value) != managed_tensor_values_.end();
+  }
+
+  bool value_is_leaked_container(const Value* value) const {
+    return leaked_values_.find(value) != leaked_values_.end();
+  }
+
+  const ValueGroup& value_group() const {
+    return value_group_;
+  }
+
+  const ManagedTensorRanges& managed_tensor_ranges() const {
+    return managed_tensor_ranges_;
+  }
+
+  void init_value_group(const AliasDb& alias_db) {
+    value_group_.init(block_, alias_db);
+  }
+
+  void prepare_for_memory_planner(
+      const AliasDb& alias_db,
+      const StaticModuleOptions& opt);
+
+  const auto& managed_output_tensor_values() const {
+    return managed_output_tensor_values_;
+  }
+
+  const auto& managed_tensor_values() const {
+    return managed_tensor_values_;
+  }
+
+  const auto& leaked_values() const {
+    return leaked_values_;
+  }
+
+ private:
+  std::vector<ProcessedNode> nodes_;
+
+  ValueGroup value_group_;
+
+  FastSet<const Node*> node_is_optimizable_container_type_;
+  FastSet<const Value*> managed_tensor_values_;
+  FastSet<const Value*> managed_output_tensor_values_;
+  FastSet<const Value*> leaked_values_;
+
+  ManagedTensorRanges managed_tensor_ranges_{};
+
+  // The index of this block's inputs in the shared values_ array.
+  const uint16_t input_idx_;
+  // The indices of this block's outputs in the shared values_ array.
+  std::vector<uint16_t> output_indices_;
+  Block& block_;
+};
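Not part of the diff: a short sketch of the index arithmetic that `block_inputs_idx()` enables. `block_input` is a hypothetical helper; the shared array itself is owned by `StaticRuntime` (see below).

// Hypothetical helper: resolve a block's i-th input in the shared array.
c10::IValue& block_input(
    std::vector<c10::IValue>& shared_values,
    const torch::jit::BlockInfo& info,
    uint32_t i) {
  // Each block's inputs start at block_inputs_idx(); its intermediates
  // follow immediately after them.
  return shared_values[info.block_inputs_idx() + i];
}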
 class TORCH_API StaticModule {
  public:
   explicit StaticModule(
@@ -231,23 +344,12 @@ class TORCH_API StaticModule {
       const StaticModuleOptions& opts = StaticModuleOptions(),
       std::vector<IValue> sample_inputs = {});

-  typedef enum {
-    CONSTANT_VALUE = -2, // VALUE nodes defined by prim::Constant
-    INPUT_VALUE = -1, // VALUE nodes representing graph inputs
-  } VALUE_KIND;
-
  private:
   explicit StaticModule(
       std::pair<std::shared_ptr<torch::jit::Graph>, c10::optional<Module>>
           graph_and_module,
       const StaticModuleOptions& opts);

-  // for <kind, idx>
-  //   if kind == CONSTANT_VALUE: map to constants_[idx]
-  //   if kind == INPUT_VALUE: map to inputs_[idx]
-  //   otherwise: map to nodes_[kind].outputs()[idx]
-  using DefInfo = std::pair<int, int>;
-
  public:
   using KeywordArgs = std::unordered_map<std::string, c10::IValue>;
   c10::IValue operator()(
@@ -268,10 +370,6 @@ class TORCH_API StaticModule {

   const StaticModuleOptions& opts() const;

-  const ValueGroup& valueGroup() const {
-    return value_group_;
-  }
-
   size_t num_inputs() const;
   size_t num_outputs() const;

@@ -295,74 +393,69 @@ class TORCH_API StaticModule {
     return constants_;
   }

+  const BlockInfo& block_info(Block* block) const {
+    return block_infos_.at(block);
+  }
+
+  Block* root_block() const {
+    return graph_->block();
+  }
+
  private:
   friend class StaticRuntime;
+  friend class BlockRunner;
-  // Our nodes don't have their inputs & outputs initialized; don't
-  // let anybody but StaticRuntime and tests get them.
-  const std::vector<ProcessedNode>& nodes() const {
-    return nodes_;
-  }

  public:
   auto num_nodes() const {
-    return nodes_.size();
+    return std::accumulate(
+        block_infos_.begin(),
+        block_infos_.end(),
+        0,
+        [](size_t sum, const auto& block_and_info) {
+          auto& block_info = block_and_info.second;
+          return sum + block_info.num_nodes();
+        });
   }

   C10_NODISCARD Node* findNodeWithKindForTesting(const std::string& kind) const;

-  graph_node_list node_ptrs() const {
-    return graph_->nodes();
-  }
-
-  bool is_optimizable_container_type(const Node* n) const {
-    auto it = node_is_optimizable_container_type_.find(n);
-    return it != node_is_optimizable_container_type_.end();
-  }
-
   const c10::optional<c10::FunctionSchema>& schema() const {
     return schema_;
   }

-  const ValueGroup& value_group() const {
-    return value_group_;
-  }
-
-  const FastSet<const Value*>& managed_tensor_values() const {
-    return managed_tensor_values_;
-  }
-
-  const FastSet<const Value*>& managed_output_tensor_values() const {
-    return managed_output_tensor_values_;
-  }
-
-  const FastSet<const Value*>& leaked_values() const {
-    return leaked_values_;
-  }
-
-  const ManagedTensorRanges& managed_tensor_ranges() const {
-    return managed_tensor_ranges_;
-  }
-
   bool first_input_is_self() const {
     return module_.has_value();
   }

-  size_t inputs_offset() const {
-    return 0;
-  }
-
-  size_t constants_offset() const {
-    return inputs_offset() + num_inputs();
-  }
-
-  size_t intermediate_values_offset() const {
-    return constants_offset() + num_constants();
-  }
-
   StaticRuntime& runtime();

+  // See [Shared values array]
+  size_t value_buffer_size() const {
+    return value_buffer_size_;
+  }
+
  private:
+  // Recursively prepares the BlockInfo array.
+  // - Populates `value_to_index` with the indices of each intermediate value
+  // - Returns the number of Value* processed, including sub-blocks.
+  size_t prepareBlockInfo(
+      Block* block,
+      const size_t start_idx,
+      FastMap<const Value*, uint32_t>& value_to_index);
+
+  void prepareFunctionsAndConstants(
+      Block* block,
+      const AliasDb& alias_db,
+      FastMap<const Value*, uint32_t>& value_to_index);
+
+  // Recurses on sub-blocks and populates the array of ProcessedNodes
+  // Returns (number of nodes processed, number of blocks processed)
+  size_t prepareProcessedNodes(
+      Block* block,
+      const FastMap<const Value*, uint32_t>& value_to_index,
+      const AliasDb& alias_db,
+      size_t node_idx = 0);
+
   // Initialize various attributes that the memory planner will need.
   // To be called at the tail of the ctor.
   void prepareForMemoryPlanner();

@@ -383,15 +476,6 @@ class TORCH_API StaticModule {
   // Indices of graph outputs in the single values array.
   std::vector<uint16_t> output_indices_;

-  ValueGroup value_group_;
-
-  FastSet<const Node*> node_is_optimizable_container_type_;
-
-  FastSet<const Value*> managed_tensor_values_{};
-  FastSet<const Value*> managed_output_tensor_values_{};
-  FastSet<const Value*> leaked_values_{};
-  ManagedTensorRanges managed_tensor_ranges_{};
-
   size_t num_intermediate_values_ = 0;

   // Includes self if module_ != nullopt.
@@ -399,16 +483,33 @@ class TORCH_API StaticModule {
   // argument. In this case, `self` isn't used in the graph, but the schema
   // includes it anyways to be consistent with the JIT interpreter.
   size_t num_inputs_;
+  // See `BlockInfo` definition. The blocks are stored in depth-first order.
+  FastMap<Block*, BlockInfo> block_infos_;
+  size_t value_buffer_size_ = 0;
 };
-class TORCH_API StaticRuntime {
+// `BlockRunner` contains the core runtime logic. Each block runner
+// corresponds to one block in the graph and has its own memory planner.
+// `StaticRuntime` will initialize all `BlockRunner`s
+// upon construction. Each block runner only directly executes nodes from its
+// block. Special ops with sub-blocks like `prim::If` may have
+// `BlockRunner`s stored in their `ProcessedNode`s; these
+// sub-blocks get executed in the op's implementation.
+// `StaticRuntime` stores a vector of IValues that all
+// `BlockRunner`s share. This vector is used to store all
+// constants, inputs, and intermediate tensors.
+class TORCH_API BlockRunner {
  public:
-  explicit StaticRuntime(const StaticModule& sm);
-  StaticRuntime(StaticRuntime&&) = delete;
-  StaticRuntime& operator=(StaticRuntime&&) = delete;
-  ~StaticRuntime();
+  BlockRunner(
+      const StaticModule& sm,
+      std::vector<IValue>& values,
+      Block* block,
+      bool is_root_block = false);
+  BlockRunner(BlockRunner&&) noexcept;
+  BlockRunner& operator=(BlockRunner&&) = delete;
+  ~BlockRunner();

-  C10_DISABLE_COPY_AND_ASSIGN(StaticRuntime);
+  C10_DISABLE_COPY_AND_ASSIGN(BlockRunner);

   using KeywordArgs = std::unordered_map<std::string, c10::IValue>;
   c10::IValue operator()(
@@ -451,11 +552,16 @@ class TORCH_API StaticRuntime {

   // Input is readwrite
   IValue& Input(uint32_t i) {
-    DCHECK_LT(i, static_module_.num_inputs());
+    DCHECK_LT(i, block_info_.num_inputs());
     DCHECK_LT(i, values_.size());
-    return values_[i];
+    return values_[i + block_info_.block_inputs_idx()];
   }

+  size_t init_sub_blocks(
+      const StaticModule& sm,
+      std::vector<IValue>& values,
+      size_t block_idx);
+
   // Output is readonly. The writing process happens inside ProcessedNodes
   C10_NODISCARD const IValue& Output(uint32_t i) const {
     DCHECK(i < outputs_.size());
@@ -475,7 +581,7 @@ class TORCH_API StaticRuntime {
   }

   graph_node_list node_ptrs() const {
-    return static_module_.node_ptrs();
+    return block_info_.node_ptrs();
   }

   const Graph& graph() const {
@@ -486,11 +592,9 @@ class TORCH_API StaticRuntime {
     return planner_.get();
   }

-  bool check_for_memory_leak(bool output_returned = true);
-
-  bool is_optimizable_container_type(Node* n) const {
-    return static_module_.is_optimizable_container_type(n);
-  }
+  bool check_for_memory_leak(
+      bool output_returned = true,
+      bool recurse_on_sub_blocks = false);

   // WARNING: Deallocate managed output tensors. A client receiving Static
   // Runtime-managed Tensors needs to be very careful to call
@@ -521,7 +625,8 @@ class TORCH_API StaticRuntime {
   // when destructed.
   class Deallocator {
    public:
-    explicit Deallocator(StaticRuntime& runtime) : runtime_(runtime) {}
+    explicit Deallocator(BlockRunner& block_runner)
+        : block_runner_(block_runner) {}

     Deallocator(Deallocator&&) = default;
     Deallocator(const Deallocator&) = default;
@@ -537,7 +642,7 @@ class TORCH_API StaticRuntime {
     void cleanupImpl();

     bool finished_ = false;
-    StaticRuntime& runtime_;
+    BlockRunner& block_runner_;
   };

   template <typename IValueList>
@@ -569,8 +674,8 @@ class TORCH_API StaticRuntime {

   // clean up owning refs of input IValues
   void clean_up_input_ivalues() noexcept {
-    for (const auto idx : c10::irange(static_module_.num_inputs())) {
-      values_[idx] = IValue();
+    for (const auto idx : c10::irange(block_info_.num_inputs())) {
+      values_[idx + inputs_begin_] = IValue();
     }
   }

@@ -591,16 +696,29 @@ class TORCH_API StaticRuntime {
       const KeywordArgs& kwargs);

   const StaticModule& static_module_;
+  const BlockInfo& block_info_;
+
+  const bool is_root_block_;
   // Cache this so we don't have to call static_module_.first_input_is_self()
   const bool first_input_is_self_;
+  // Index of the start of this blocks inputs in the shared values_ array.
+  const uint16_t inputs_begin_;

   bool manage_output_tensors_enabled_ = false;
   std::unique_ptr<MemoryPlanner> planner_;
-  // first static_module_.num_inputs() slots are inputs, next
-  // static_module_.constants().size() slots are a copy of
-  // static_module_.constants(), rest are regular values in the
-  // graph. ProcessedNodes reference their inputs and outputs with
+  // [Shared values array]
+  // ProcessedNodes reference their inputs and outputs with
   // offsets into this array, which saves memory.
-  std::vector<IValue> values_;
+  // All BlockRunners share the same array. The layout is as
+  // follows:
+  // [constants][block_0][block_1]...[block_N]
+  // Note that constants from all blocks are pooled together at the start.
+  // The block ordering is depth-first.
+  // Each block is further divided into inputs and intermediates:
+  // [block_i] = [inputs_i][intermediates_i]
+  // Each BlockRunner knows where its inputs start. Each ProcessedNode
+  // knows how to find the indices of its outputs/inputs in this array.
+  std::vector<IValue>& values_;
   std::vector<IValue*> outputs_;
   std::vector<ProcessedNode> nodes_;
 };
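For orientation, a worked example of the shared values layout described in the comment above, using hypothetical sizes (not taken from the patch): 3 pooled constants, a root block with 2 inputs and 4 intermediates, and one nested block with 1 input and 2 intermediates.

#include <cstddef>

// Hypothetical sizes; only the layout rule comes from the comment above.
constexpr std::size_t kNumConstants = 3;
constexpr std::size_t kRootInputsBegin = kNumConstants;                  // 3
constexpr std::size_t kRootIntermediatesBegin = kRootInputsBegin + 2;    // 5
constexpr std::size_t kSubBlockInputsBegin = kRootIntermediatesBegin + 4;
static_assert(kSubBlockInputsBegin == 9, "blocks are laid out depth-first");
// BlockRunner::Input(i) then resolves to values_[inputs_begin_ + i], where
// inputs_begin_ would be 3 for the root runner and 9 for the nested runner.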
@@ -643,15 +761,55 @@ class TORCH_API ProcessedNode {
   ProcessedNode() = default;
   // ProcessedNodes are created within StaticModule and then
   // associated with a shared values array using set_values() when
-  // they are copied into a StaticRuntime.
+  // they are copied into a StaticRuntime. block_runners_ are also
+  // not initialized until StaticRuntime initialization; see
+  // BlockRunner's ctor.
   ProcessedNode(
       Node* n,
       ProcessedFunction* fn,
       ProcessedNodeInputs inputs,
       uint16_t outputs_offset);

-  ProcessedNode(const ProcessedNode&) = default;
-  ProcessedNode& operator=(const ProcessedNode&) = default;
+  ProcessedNode(const ProcessedNode& other)
+      : node_(other.node_),
+        fn_(other.fn_),
+        overlap_detected_(other.overlap_detected_),
+        inputs_(other.inputs_),
+        outputs_offset_(other.outputs_offset_),
+        num_outputs_(other.num_outputs_),
+        values_(other.values_),
+        // It doesn't really make sense to copy block runners,
+        // each processed node needs its own. This is OK to do
+        // since ProcessedNodes are copied from StaticModule right before
+        // the block runners are set up.
+        // TODO(T105178680): For this task, we should move
+        // block runners out of ProcessedNode. Then, we don't have to deal
+        // with this caveat.
+        block_runners_(nullptr)
+#ifndef PYTORCH_DISABLE_PER_OP_PROFILING
+        ,
+        op_name_(other.op_name_)
+#endif
+  {
+  }
+
+  ProcessedNode& operator=(const ProcessedNode& other) {
+    if (&other == this) {
+      return *this;
+    }
+    node_ = other.node_;
+    fn_ = other.fn_;
+    overlap_detected_ = other.overlap_detected_;
+    inputs_ = other.inputs_;
+    outputs_offset_ = other.outputs_offset_;
+    num_outputs_ = other.num_outputs_;
+    values_ = other.values_;
+    block_runners_ = nullptr;
+#ifndef PYTORCH_DISABLE_PER_OP_PROFILING
+    op_name_ = other.op_name_;
+#endif
+    return *this;
+  }

   // These should be noexcept, but some Android build is failing
   // saying the noexcept specification doesn't match the calculated
@@ -732,11 +890,21 @@ class TORCH_API ProcessedNode {
   }

   C10_NODISCARD uint16_t output_ivalue_index(uint16_t i) const {
+    DCHECK(i < num_outputs_);
     return outputs_offset_ + i;
   }
   // used in debug mode
   bool verify_no_memory_overlap(bool force_check = false) const;

+  std::vector<BlockRunner>* block_runners() {
+    return block_runners_.get();
+  }
+
+  void set_block_runners(
+      std::unique_ptr<std::vector<BlockRunner>> block_runners) {
+    block_runners_ = std::move(block_runners);
+  }
+
  private:
   C10_NODISCARD bool verify_outputs_dont_overlap_each_other() const;
@@ -749,10 +917,74 @@ class TORCH_API ProcessedNode {
   uint16_t outputs_offset_;
   uint16_t num_outputs_;
   IValue* values_ = nullptr; // unowned
+  // For control flow; processed nodes may have sub-blocks which can
+  // be executed by op implementations.
+  std::unique_ptr<std::vector<BlockRunner>> block_runners_;
 #ifndef PYTORCH_DISABLE_PER_OP_PROFILING
   const char* op_name_;
 #endif
 };
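A rough sketch, not part of this excerpt, of how a control-flow op's implementation might use the sub-block runners attached to its ProcessedNode. The real prim::If kernel is not shown in this diff; `run_if_like_op` and its argument handling are hypothetical and simplified.

// Hypothetical control-flow kernel using the block_runners() accessor above.
void run_if_like_op(torch::jit::ProcessedNode* p_node) {
  auto* runners = p_node->block_runners();
  TORCH_CHECK(runners != nullptr && runners->size() == 2);
  const bool cond = p_node->Input(0).toBool();
  auto& branch = (*runners)[cond ? 0 : 1];
  // Sub-block runners read and write the same shared values array, so only
  // the block's explicit inputs need forwarding (none in this toy case).
  std::vector<c10::IValue> branch_args;
  p_node->Output(0) = branch(branch_args, {});
}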
+// `StaticRuntime` is the owner of the array of IValues (used for constants,
+// inputs, and intermediate tensors) that all `BlockRunner`s share.
+// Upon construction, it initializes all block runners. `operator()` simply
+// forwards the inputs to the top-level block runner. Each `StaticRuntime`
+// instance corresponds to one `StaticModule`. Multiple `StaticRuntime`
+// instances can be created; this is useful for multi-threaded execution, since
+// `operator()` is not thread-safe.
+class TORCH_API StaticRuntime {
+ public:
+  explicit StaticRuntime(const StaticModule& sm);
+
+  using KeywordArgs = std::unordered_map<std::string, c10::IValue>;
+  c10::IValue operator()(
+      const std::vector<c10::IValue>& args,
+      const KeywordArgs& kwargs = KeywordArgs());
+  c10::IValue operator()(
+      std::vector<c10::IValue>&& args,
+      const KeywordArgs& kwargs = KeywordArgs());
+
+  bool check_for_memory_leak(bool output_returned = true);
+  bool checkOutputTensorMemoryLeaks();
+
+  void deallocateOutputTensors();
+  bool isManagedOutputTensor(const IValue& ivalue) const;
+  void disableManageOutputTensors();
+
+  // Gets the top-level memory planner. Used for testing.
+  const MemoryPlanner* get_memory_planner() const;
+
+  void benchmark(
+      const std::vector<std::vector<c10::IValue>>& args_list,
+      const std::vector<KeywordArgs>& kwargs_list,
+      const int warmup_runs,
+      const int main_runs,
+      bool print_per_node_time = false,
+      bool generate_ai_pep_output = false) {
+    block_->benchmark(
+        args_list,
+        kwargs_list,
+        warmup_runs,
+        main_runs,
+        print_per_node_time,
+        generate_ai_pep_output);
+  }
+
+  using IndividualMetrics = BlockRunner::IndividualMetrics;
+
+  IndividualMetrics benchmark_individual_ops(
+      const std::vector<std::vector<c10::IValue>>& args_list,
+      const std::vector<KeywordArgs>& kwargs_list,
+      const int warmup_runs,
+      const int main_runs) {
+    return block_->benchmark_individual_ops(
+        args_list, kwargs_list, warmup_runs, main_runs);
+  }
+
+ private:
+  std::unique_ptr<BlockRunner> block_;
+  std::vector<IValue> values_;
+};
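Outside the patch: a small sketch of the threading model the comment above describes, with one StaticRuntime (and therefore one values_ buffer and one set of BlockRunners) per worker thread, all sharing a single StaticModule. `serve_two_requests` is a hypothetical driver.

#include <thread>

// Hypothetical driver illustrating one-runtime-per-thread usage.
void serve_two_requests(
    const torch::jit::StaticModule& sm,
    const std::vector<c10::IValue>& args) {
  auto worker = [&sm, &args] {
    torch::jit::StaticRuntime runtime(sm); // per-thread runtime
    runtime(args);
  };
  std::thread t0(worker);
  std::thread t1(worker);
  t0.join();
  t1.join();
}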
 } // namespace jit
 } // namespace torch
@@ -129,10 +129,10 @@ bool setIncludes(const FastSet<const Value*>& set, const Value* v) {
 }

 std::vector<std::pair<size_t, at::Tensor*>> assignStorageToOutputTensors(
-    StaticRuntime* runtime,
+    BlockRunner* block_runner,
     const FastSet<const Value*>& managed_output_tensor_values) {
   std::vector<std::pair<size_t, at::Tensor*>> managed_output_tensors;
-  for (auto& pnode : runtime->nodes()) {
+  for (auto& pnode : block_runner->nodes()) {
     for (const auto i : c10::irange(pnode.outputs().size())) {
       auto& ival = pnode.Output(i);
       const auto* val = pnode.node()->outputs()[i];
@@ -151,19 +151,20 @@ std::vector<std::pair<size_t, at::Tensor*>> assignStorageToOutputTensors(
 } // namespace

 MemoryPlanner::MemoryPlanner(
-    StaticRuntime* runtime,
-    const ValueGroup& value_group,
-    const FastSet<const Value*>& managed_tensor_values,
-    const FastSet<const Value*>& managed_output_tensor_values,
-    const FastSet<const Value*>& leaked_values,
-    const ManagedTensorRanges& ranges,
+    BlockRunner* block_runner,
+    const BlockInfo& block_info,
     bool enable_out_variant,
     bool manage_output_tensors,
     bool optimize_memory) {
+  const auto& managed_tensor_values = block_info.managed_tensor_values();
+  const auto& managed_output_tensor_values =
+      block_info.managed_output_tensor_values();
+  const auto& leaked_values = block_info.leaked_values();
+
   // collect unmanaged output ivalues
   FastSet<IValue*> unmanaged_ivalues;
   FastSet<IValue*> unmanaged_borrowed_ivalues;
-  for (ProcessedNode& pnode : runtime->nodes()) {
+  for (ProcessedNode& pnode : block_runner->nodes()) {
     const auto borrows_outputs = borrowsOutputs(pnode.node()->kind());
     for (const auto i : c10::irange(pnode.outputs().size())) {
       const Value* out_v = pnode.node()->outputs()[i];
@@ -189,7 +190,7 @@ MemoryPlanner::MemoryPlanner(
       }
     }
   }
-  for (IValue* output : runtime->outputs()) {
+  for (IValue* output : block_runner->outputs()) {
     auto it = unmanaged_borrowed_ivalues.find(output);
     if (it != unmanaged_borrowed_ivalues.end()) {
       borrowed_ivalues_needing_incref_.push_back(output);
@@ -213,10 +214,12 @@ MemoryPlanner::MemoryPlanner(

   if (enable_out_variant) {
     const auto tensor_value_to_tensor =
-        tensorValueToTensor(runtime->nodes(), managed_tensor_values);
+        tensorValueToTensor(block_runner->nodes(), managed_tensor_values);
     if (optimize_memory) {
       managed_tensors_ = assignStorageToManagedTensors(
-          runtime->node_ptrs(), ranges, tensor_value_to_tensor);
+          block_info.node_ptrs(),
+          block_info.managed_tensor_ranges(),
+          tensor_value_to_tensor);
     } else {
       for (auto& tensor : tensor_value_to_tensor) {
         managed_tensors_.emplace_back(tensor.second);
@@ -225,8 +228,8 @@ MemoryPlanner::MemoryPlanner(
   }

   if (enable_out_variant && manage_output_tensors) {
-    managed_output_tensors_ =
-        assignStorageToOutputTensors(runtime, managed_output_tensor_values);
+    managed_output_tensors_ = assignStorageToOutputTensors(
+        block_runner, managed_output_tensor_values);
   }

   num_managed_tensors_ = 0;
@@ -93,12 +93,8 @@ TORCH_API std::vector<StorageGroup> assignStorageToManagedTensors(
 class MemoryPlanner {
  public:
   explicit MemoryPlanner(
-      StaticRuntime* runtime,
-      const ValueGroup& value_group,
-      const FastSet<const Value*>& managed_tensor_values,
-      const FastSet<const Value*>& managed_output_tensor_values,
-      const FastSet<const Value*>& leaked_values,
-      const ManagedTensorRanges& ranges,
+      BlockRunner* block_runner,
+      const BlockInfo& block_info,
       bool enable_out_variant,
       bool manage_output_tensors,
       bool optimize_memory);
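Not part of the diff: a hypothetical wrapper showing how the slimmed-down MemoryPlanner constructor is meant to be called. The real call site lives inside BlockRunner and is not shown in this excerpt; the option names come from StaticModuleOptions.

// Hypothetical helper; only the 5-argument constructor comes from the patch.
std::unique_ptr<torch::jit::MemoryPlanner> make_planner(
    torch::jit::BlockRunner* block_runner,
    const torch::jit::BlockInfo& block_info,
    const torch::jit::StaticModuleOptions& opts) {
  // Managed/leaked value sets and tensor ranges are no longer passed one by
  // one; MemoryPlanner pulls them from the BlockInfo.
  return std::make_unique<torch::jit::MemoryPlanner>(
      block_runner,
      block_info,
      opts.enable_out_variant,
      opts.manage_output_tensors,
      opts.optimize_memory);
}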