Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-07 00:21:07 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/69274

`impl.h` is the main header file that defines the interface of Static Runtime to its clients. However, it is currently filled with implementation details that should not be leaked to our clients. This 1) unnecessarily exposes our internals to clients, which can make them hard to change later, and 2) causes unnecessary merge conflicts when multiple people are touching this enormous impl.cpp file.

To alleviate the situation, this change moves the implementation details from impl.h into a new file, internal.h, which is kept internal without leaking the details to our clients. This change will be followed by another change that renames `impl.h` to `runtime.h` (or anything better), since `impl.h` is currently not about implementation but about SR's interface.

Note that this change is NOT complete, since the remaining declarations in impl.h still contain a lot of implementation details. Therefore, we should keep working on minimizing the interface to prevent our API from being bloated unnecessarily. We also need to work on modularizing our implementation into separate pieces organized by separate files in the near future.

Test Plan: Existing unittests

Reviewed By: donaldong

Differential Revision: D32780415

fbshipit-source-id: 119b7aedbf563b195641c5674572a9348732145f
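The intended split can be pictured with a minimal sketch (illustrative only; the declarations and their placement below are assumptions, not the actual contents of the two headers):

// impl.h -- the public Static Runtime interface that clients include (sketch)
#pragma once
#include <memory>

namespace torch {
namespace jit {

struct Graph;
class StaticModule; // client-facing API stays in the public header

// Interface-level queries such as this one also stay public.
bool canEnableStaticRuntime(const std::shared_ptr<Graph>& graph);

} // namespace jit
} // namespace torch

// internal.h -- implementation-only helpers kept out of the client-facing API
// (sketch; the real file declares the actual analysis utilities used by impl.cpp)
#pragma once

namespace torch {
namespace jit {

class GraphAnalysisHelper; // hypothetical name, for illustration only

} // namespace jit
} // namespace torch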
1754 lines · 59 KiB · C++
#include <torch/csrc/jit/runtime/static/impl.h>

#include <ATen/MemoryOverlap.h>
#include <ATen/core/interned_strings.h>
#include <ATen/record_function.h>
#include <c10/core/CPUAllocator.h>
#include <c10/core/InferenceMode.h>
#include <c10/util/irange.h>
#include <caffe2/core/scope_guard.h>
#include <caffe2/core/timer.h>
#include <torch/csrc/jit/ir/alias_analysis.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/canonicalize.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/eliminate_no_ops.h>
#include <torch/csrc/jit/passes/freeze_module.h>
#include <torch/csrc/jit/passes/remove_mutation.h>
#include <torch/csrc/jit/passes/subgraph_rewrite.h>
#include <torch/csrc/jit/passes/variadic_ops.h>
#include <torch/csrc/jit/runtime/static/internal.h>
#include <torch/csrc/jit/runtime/static/memory_planner.h>
#include <torch/csrc/jit/runtime/static/ops.h>
#include <torch/csrc/jit/runtime/static/passes.h>
#include <torch/csrc/jit/runtime/vararg_functions.h>
#include <iterator>
#include <sstream>
#include <stdexcept>

#ifdef FBCODE_CAFFE2
#include <folly/dynamic.h>
#include <folly/json.h>
#endif

namespace torch {
namespace jit {

// A manually curated set of ops that are disallowed in static runtime.
// These are rarely-used ops. Disallowing them typically eliminates
// corner cases in graph optimizations, allowing for more aggressive
// optimizations and better performance.
bool isUnsupportedOp(const NodeKind& kind) {
  return kind == aten::__is__ || kind == aten::__isnot__;
}

// The graph must be frozen, or canEnableStaticRuntime will return false if
// there is any prim::CallMethod op left in the graph.
bool canEnableStaticRuntime(const std::shared_ptr<torch::jit::Graph>& graph) {
  // check for sub-blocks
  bool can_support = true;
  bool has_blocks = false;
  for (auto* node : graph->block()->nodes()) {
    if (node->blocks().size() > 0) {
      has_blocks = true;
      VLOG(1) << "Found nested sub-blocks in graph at node: "
              << PrintNode(node);
    }
    const auto kind = node->kind();
    if (kind == prim::Constant) {
      continue;
    }
    // check if we can get an op from the Node
    const Operator* op = node->maybeOperator();
    if (isUnsupportedOp(kind) || (!op && !nativeOpIsRegistered(kind))) {
      can_support = false;
      LOG(WARNING) << "Found unsupported op: " << kind.toQualString();
    }
  }
  if (has_blocks) {
    LOG(WARNING)
        << "Found nested sub-block in graph. Static Runtime doesn't support nested sub-blocks.";
    can_support = false;
  }
  return can_support;
}

namespace {

void OptimizeGraph(
    std::shared_ptr<torch::jit::Graph>& graph,
    const StaticModuleOptions& opts) {
  GRAPH_DUMP("Before optimizations: ", graph);
  Inline(*graph);
  ConstantPropagation(graph);
  Canonicalize(graph);
  ConstantPropagation(graph);
  RemoveTensorMutation(graph);
  ConstantPropagation(graph);
  EliminateDeadCode(graph);
  FuseInferenceOpsForSparseNN(graph);
  UseVariadicCat(graph);
  UseVariadicStack(graph);
  EliminateTrivialEquallySplit(graph);

  if (opts.enable_out_variant) {
    UseVariadicOp(
        graph,
        fromQualString("fb::sigrid_transforms_torch_bind"),
        fromQualString("fb::variadic_sigrid_transforms_torch_bind"));
    FuseSignLog1P(graph);

    // TODO: we can avoid this guard by moving operations
    // to exposed folders.
#ifdef FBCODE_CAFFE2
    ReplaceWithCopy(graph);
    FuseListUnpack(graph);
    EnableStaticRuntimeLayerNorm(graph);
#endif
  }

  ConstantPropagation(graph);
  RemoveImmutableInputDictLookups(graph);
  UseVariadicTupleUnpack(graph);
  UseVariadicGroupedAccessor(graph);
  EliminateNoOps(
      graph, /* custom_ops */ {fromQualString("fb::scale_gradient")});
  GRAPH_DUMP("Final graph after optimizations: ", graph);
}

// remove unused input 0 from graph
bool RemoveSelfFromGraphInput(std::shared_ptr<torch::jit::Graph>& graph) {
  if (graph->inputs().at(0)->type()->is_module()) {
    if (graph->inputs().at(0)->hasUses()) {
      return false;
    }
    graph->eraseInput(0);
  }
  return true;
}

// remove "self" from function schema
c10::FunctionSchema RemoveSelfFromSchema(const c10::FunctionSchema& s) {
  TORCH_CHECK(s.arguments().size() >= 1 && s.arguments()[0].name() == "self");
  std::vector<Argument> args({s.arguments().begin() + 1, s.arguments().end()});
  return s.cloneWithArguments(args);
}

// Map each value to all values that are alive at the same time.
using LivenessMap = FastMap<const Value*, FastSet<const Value*>>;

template <typename Map>
std::string dumpMapFromValuesToListsOrSetsOfOtherValues(const Map& map) {
  std::ostringstream oss;
  oss << "{";
  for (const auto& p : map) {
    oss << "{%" << p.first->debugName() << ": {";
    for (const auto* val : p.second) {
      oss << "%" << val->debugName() << ", ";
    }
    oss << "}},\n";
  }
  oss << "}";
  return oss.str();
}

std::string dumpLivenessMap(const LivenessMap& liveness_map) {
  return dumpMapFromValuesToListsOrSetsOfOtherValues(liveness_map);
};

// The algorithm does a traversal of the execution graph
// while keeping track of the live values.
LivenessMap GetLivenessMap(
    const std::shared_ptr<torch::jit::Graph>& graph,
    const ValueGroup& value_group,
    AliasDb& db) {
  // map a Value to the set of Values whose live-ranges overlap with its own
  FastMap<const Value*, FastSet<const Value*>> liveness_map;

  // map each Value to its creation order in the graph (Note: we only traverse
  // top-level nodes, so nodes under control flow are represented by their
  // top-level block nodes)
  std::vector<const Value*> values_in_creation_order;
  FastMap<const Value*, size_t> values_to_idx_in_creation_order;

  // presence of a Value in live_values_use_chain means the Value is alive;
  // each Value is mapped to the set of Nodes that may use it (i.e., the
  // use-chain of the Value)
  FastMap<const Value*, FastSet<const Node*>> live_values_use_chain;
  // each Node is mapped to the set of Values that it may use (i.e., the
  // def-chain of the node's inputs)
  FastMap<const Node*, FastSet<const Value*>> live_nodes_def_chain;

{
|
|
// Set container capacity
|
|
size_t live_values_size = 0, live_nodes_size = 0,
|
|
values_in_creation_size = 0;
|
|
for (const auto* node : graph->nodes()) {
|
|
bool has_live_value = false;
|
|
for (const auto* v : node->outputs()) {
|
|
++values_in_creation_size;
|
|
if (!value_group.isAlwaysAlive(v)) {
|
|
++live_values_size;
|
|
has_live_value = true;
|
|
}
|
|
}
|
|
// All inputs and outputs should be alive at the same time.
|
|
live_values_size += node->inputs().size();
|
|
if (has_live_value) {
|
|
++live_nodes_size;
|
|
}
|
|
}
|
|
|
|
live_nodes_def_chain.reserve(live_nodes_size);
|
|
live_values_use_chain.reserve(live_values_size);
|
|
liveness_map.reserve(live_values_size);
|
|
values_in_creation_order.reserve(values_in_creation_size);
|
|
values_to_idx_in_creation_order.reserve(values_in_creation_size);
|
|
}
|
|
|
|
// Construct values_in_creation
|
|
for (const auto* node : graph->nodes()) {
|
|
for (const auto* v : node->outputs()) {
|
|
values_to_idx_in_creation_order.emplace(
|
|
v, values_in_creation_order.size());
|
|
values_in_creation_order.emplace_back(v);
|
|
}
|
|
}
|
|
|
|
// add v to the current liveness_map
|
|
std::function<void(const Value* v)> add_live_value_fn = [&](const Value* v) {
|
|
if (liveness_map.count(v)) {
|
|
return;
|
|
}
|
|
|
|
auto& v_live_set = liveness_map[v] = {};
|
|
|
|
v_live_set.reserve(live_values_use_chain.size());
|
|
for (const auto& live_v : live_values_use_chain) {
|
|
v_live_set.insert(live_v.first);
|
|
liveness_map[live_v.first].insert(v);
|
|
}
|
|
|
|
// only add values to the live set if they
|
|
// have deps, otherwise they die immediately
|
|
if (v->uses().size()) {
|
|
live_values_use_chain[v] = FastSet<const Node*>(v->uses().size());
|
|
// record the relationship between v (Value) and its uses (Node)
|
|
for (const auto& u : v->uses()) {
|
|
const auto* node = u.user;
|
|
live_values_use_chain[v].insert(node);
|
|
live_nodes_def_chain[node].insert(v);
|
|
}
|
|
}
|
|
|
|
// FIXME(penguin): the following alias refinement seems to assume
|
|
// that `v` refers to a new tensor created by the node that defines
|
|
// v, thus other Values "before" the node that defines `v` cannot
|
|
// possibly be aliased to `v`.
|
|
// TODO(penguin): Is it a limitation of TS alias analysis
// that we need to do such refinement? If so, we should improve
// alias analysis so that we don't need this special handling here.
|
|
//
|
|
// Refine aliases of v by include only those created after v
|
|
std::vector<const Value*> refined_aliases;
|
|
auto idx = values_to_idx_in_creation_order[v];
|
|
for (; idx < values_in_creation_order.size(); ++idx) {
|
|
auto* alias_v = values_in_creation_order[idx];
|
|
if (mayContainAlias(db, v, alias_v)) {
|
|
refined_aliases.emplace_back(alias_v);
|
|
}
|
|
}
|
|
// for all the values in the alias set,
|
|
// we set them "alive"
|
|
for (auto* aliased_v : refined_aliases) {
|
|
GRAPH_DEBUG(
|
|
"aliased_v: %",
|
|
aliased_v->debugName(),
|
|
" (for %",
|
|
v->debugName(),
|
|
")");
|
|
add_live_value_fn(aliased_v);
|
|
}
|
|
};
|
|
|
|
auto remove_dead_values = [&](const Node* node) {
|
|
auto find = live_nodes_def_chain.find(node);
|
|
if (find != live_nodes_def_chain.end()) {
|
|
for (const auto* v : find->second) {
|
|
live_values_use_chain[v].erase(node);
|
|
if (!live_values_use_chain[v].size()) {
|
|
// v is now dead
|
|
GRAPH_DEBUG(
|
|
"%",
|
|
v->debugName(),
|
|
" is now dead after ",
|
|
node->output(0)->debugName())
|
|
live_values_use_chain.erase(v);
|
|
}
|
|
}
|
|
}
|
|
};
|
|
|
|
for (const auto* node : graph->nodes()) {
|
|
for (const auto* v : node->outputs()) {
|
|
if (!value_group.isAlwaysAlive(v)) {
|
|
add_live_value_fn(v);
|
|
}
|
|
}
|
|
|
|
remove_dead_values(node);
|
|
}
|
|
GRAPH_DEBUG("LivenessMap: ", dumpLivenessMap(liveness_map));
|
|
|
|
for (const auto& v : live_values_use_chain) {
|
|
TORCH_CHECK(
|
|
value_group.isAlwaysAlive(v.first),
|
|
v.first->debugName(),
|
|
"is not in the value_group.isAlwaysAlive group");
|
|
}
|
|
|
|
auto insert_all_pairs_in_liveness_map =
|
|
[&](at::ArrayRef<const Value*> values) {
|
|
for (size_t i = 0; !values.empty() && i < values.size() - 1; ++i) {
|
|
auto value_it = liveness_map.find(values[i]);
|
|
if (value_it == liveness_map.end()) {
|
|
continue;
|
|
}
|
|
for (size_t j = i + 1; j < values.size(); ++j) {
|
|
auto value2_it = liveness_map.find(values[j]);
|
|
if (value2_it != liveness_map.end()) {
|
|
value_it->second.insert(values[j]);
|
|
value2_it->second.insert(values[i]);
|
|
}
|
|
}
|
|
}
|
|
};
|
|
|
|
for (const auto* node : graph->nodes()) {
|
|
auto inputs = node->inputs();
|
|
auto outputs = node->outputs();
|
|
for (const auto* input : inputs) {
|
|
for (const auto* output : outputs) {
|
|
auto input_it = liveness_map.find(input);
|
|
if (input_it == liveness_map.end()) {
|
|
continue;
|
|
}
|
|
auto output_it = liveness_map.find(output);
|
|
if (output_it == liveness_map.end()) {
|
|
continue;
|
|
}
|
|
input_it->second.insert(output);
|
|
output_it->second.insert(input);
|
|
}
|
|
}
|
|
|
|
// All inputs should be alive at the same time.
|
|
insert_all_pairs_in_liveness_map(inputs);
|
|
|
|
// All outputs should be alive at the same time.
|
|
insert_all_pairs_in_liveness_map(outputs);
|
|
};
|
|
|
|
return liveness_map;
|
|
};
|
|
|
|
// Collect the set of Values that are candidates for memory planning:
|
|
// - Values that are used in in-place operators (i.e., _out variants), and
|
|
// - excluding those that are either inputs or outputs of
|
|
// non in-place operators
|
|
//
|
|
// Returns
|
|
// first: Values that are candidates for memory planning
|
|
// second: A deterministic order of all values
|
|
std::pair<std::vector<const Value*>, std::vector<const Value*>>
|
|
GetMemoryPlanningCandidates(
|
|
const std::shared_ptr<torch::jit::Graph>& graph,
|
|
const FastMap<Node*, bool>& node_has_out_variant) {
|
|
// for determinism
|
|
FastSet<const Value*> seen_values;
|
|
std::vector<const Value*> all_values;
|
|
FastSet<const Value*> can_reuse;
|
|
// values used by unsupported ops (as either inputs or outputs)
|
|
// these need to be removed from "can_reuse" after analyzing all nodes
|
|
FastSet<const Value*> cannot_reuse;
|
|
for (auto* n : graph->nodes()) {
|
|
bool can_reuse_inputs_outputs =
|
|
canReuseInputsOutputs(n, node_has_out_variant);
|
|
for (const auto* v : n->inputs()) {
|
|
if (!seen_values.count(v)) {
|
|
all_values.emplace_back(v);
|
|
seen_values.insert(v);
|
|
}
|
|
if (can_reuse_inputs_outputs) {
|
|
can_reuse.insert(v);
|
|
} else {
|
|
cannot_reuse.insert(v);
|
|
}
|
|
}
|
|
for (const auto* v : n->outputs()) {
|
|
all_values.emplace_back(v);
|
|
seen_values.insert(v);
|
|
if (can_reuse_inputs_outputs) {
|
|
can_reuse.insert(v);
|
|
} else {
|
|
cannot_reuse.insert(v);
|
|
}
|
|
}
|
|
}
|
|
for (const auto* v : cannot_reuse) {
|
|
can_reuse.erase(v);
|
|
}
|
|
// find a deterministic order
|
|
std::vector<const Value*> optimizable;
|
|
for (const auto* v : all_values) {
|
|
if (can_reuse.count(v)) {
|
|
optimizable.emplace_back(v);
|
|
can_reuse.erase(v);
|
|
}
|
|
}
|
|
return std::make_pair(optimizable, all_values);
|
|
}
|
|
|
|
// Equipped with a liveness map we can allocate memory to
|
|
// ivalues, reusing memory along the way. However, we are
|
|
// constrained by the set of optimizable_values
|
|
// (inputs/outputs of out variants). Inputs/outputs of view ops
|
|
// can't be reused.
|
|
//
|
|
// Algorithm:
|
|
// # clusters of values sharing the same memory
|
|
// # are called "value_to_same_storage_values" in the implementation
|
|
// # inserting into a cluster denotes sharing memory.
|
|
//
|
|
// clusters = {}
|
|
// for all v in optimizable_values:
//   for all cluster in clusters: # can we insert into cluster?
//     for all live_v in live_during(v):
//       if cluster.contains(live_v):
//         skip to next cluster
//     cluster.add(v)
//     skip to next v
//   if no cluster found:
//     clusters.add(cluster{v})
|
|
//
|
|
//
|
|
// NB: This is a deterministic implementation, which makes it easier to tune
|
|
// and debug.
|
|
FastMap<const Value*, std::vector<const Value*>> GenerateSameStorageValues(
|
|
const LivenessMap& alive_during,
|
|
const ValueGroup& value_group,
|
|
const std::pair<std::vector<const Value*>, std::vector<const Value*>>&
|
|
optimizable,
|
|
AliasDb& db) {
|
|
const auto& optimizable_values = optimizable.first;
|
|
const auto& all_values = optimizable.second;
|
|
|
|
// map each Value* to the set of Value*s that can share the same storage with it
|
|
FastMap<const Value*, std::vector<const Value*>> same_storage_values;
|
|
|
|
// make new_v and old_v map to the same storage (i.e., add to each other's
|
|
// same_storage_values set)
|
|
auto share_storage_fn = [&](const Value* new_v, const Value* old_v) {
|
|
if (new_v == old_v) {
|
|
return;
|
|
}
|
|
DCHECK(same_storage_values.count(old_v));
|
|
FastSet<const Value*> seen;
|
|
std::vector<const Value*> values;
|
|
for (auto* v : same_storage_values.at(old_v)) {
|
|
if (seen.count(v)) {
|
|
continue;
|
|
}
|
|
seen.insert(v);
|
|
values.emplace_back(v);
|
|
}
|
|
for (auto* v : same_storage_values.at(new_v)) {
|
|
if (seen.count(v)) {
|
|
continue;
|
|
}
|
|
seen.insert(v);
|
|
values.emplace_back(v);
|
|
}
|
|
for (const auto* v : values) {
|
|
same_storage_values[v] = values;
|
|
}
|
|
};
|
|
|
|
// initialize with known same_storage_values (aliasing values)
|
|
for (const auto* v : all_values) {
|
|
if (!same_storage_values.count(v)) {
|
|
same_storage_values[v] = {v};
|
|
}
|
|
// NOTE: if we had AliasDb::mustAlias, we could do the following:
|
|
// // skip always alive values (alias inputs/outputs/weights)
|
|
// if (value_group.isAlwaysAlive(v)) {
|
|
// continue;
|
|
// }
|
|
// for (const auto& p : same_storage_values) {
|
|
// if (db.mustAlias(p.first, v)) {
|
|
// share_storage_fn(v, p.first);
|
|
// }
|
|
// }
|
|
// It also wouldn't matter because ops always create new Tensor
|
|
// objects as aliases; there is no point in trying to reuse their
|
|
// storage.
|
|
}
|
|
|
|
// to preserve determinism
|
|
std::vector<const Value*> seen;
|
|
|
|
auto compute_liveset_fn = [&alive_during, &same_storage_values](
|
|
FastSet<const Value*>& live, const Value* v) {
|
|
for (const auto* sv : same_storage_values.at(v)) {
|
|
const auto& l = alive_during.count(sv) ? alive_during.at(sv)
|
|
: FastSet<const Value*>{};
|
|
live.insert(l.begin(), l.end());
|
|
}
|
|
};
|
|
|
|
// check if same_storage_values[s] intersects with live
|
|
auto intersect_fn = [&same_storage_values](
|
|
FastSet<const Value*>& live, const Value* s) {
|
|
bool intersect = false;
|
|
for (const auto* v : same_storage_values.at(s)) {
|
|
if (live.count(v)) {
|
|
intersect = true;
|
|
break;
|
|
}
|
|
}
|
|
return intersect;
|
|
};
|
|
|
|
for (const auto* v : optimizable_values) {
|
|
if (value_group.isAlwaysAlive(v)) {
|
|
continue;
|
|
}
|
|
// get values that are live during the lifetime of v
|
|
FastSet<const Value*> live;
|
|
compute_liveset_fn(live, v);
|
|
for (const auto* s : seen) {
|
|
// if live(same_storage_values[v]) and same_storage_values[s]
|
|
// do not overlap, then s and v can share the same storage
|
|
if (!intersect_fn(live, s) && !value_group.isAlwaysAlive(s)) {
|
|
share_storage_fn(v, s);
|
|
// since s is added to same_storage_values[v], live needs
|
|
// to be recomputed, so bail out here
|
|
break;
|
|
}
|
|
}
|
|
seen.emplace_back(v);
|
|
}
|
|
|
|
GRAPH_DEBUG(
|
|
"same_storage_values: ",
|
|
dumpMapFromValuesToListsOrSetsOfOtherValues(same_storage_values));
|
|
|
|
return same_storage_values;
|
|
}
|
|
|
|
void PrepareGraphForStaticModule(
|
|
std::shared_ptr<torch::jit::Graph> graph,
|
|
const StaticModuleOptions& opts) {
|
|
TORCH_CHECK(canEnableStaticRuntime(graph));
|
|
OptimizeGraph(graph, opts);
|
|
}
|
|
|
|
std::pair<std::shared_ptr<Graph>, c10::optional<Module>> PrepareForStaticModule(
|
|
const torch::jit::Module& m,
|
|
bool is_frozen,
|
|
const StaticModuleOptions& opts) {
|
|
VLOG(1) << "StaticModuleOptions: cleanup_activations "
|
|
<< opts.cleanup_activations << ", enable_out_variant "
|
|
<< opts.enable_out_variant << ", optimize_memory "
|
|
<< opts.optimize_memory << ", manage_output_tensors "
|
|
<< opts.manage_output_tensors;
|
|
|
|
Module module = m.copy();
|
|
if (!is_frozen) {
|
|
module.eval();
|
|
module = freeze_module(module);
|
|
}
|
|
|
|
Method method = module.get_method("forward");
|
|
auto graph = module.get_method("forward").graph();
|
|
|
|
PrepareGraphForStaticModule(graph, opts);
|
|
|
|
return std::make_pair(graph, module);
|
|
}
|
|
|
|
std::pair<std::shared_ptr<Graph>, c10::optional<Module>> PrepareForStaticModule(
|
|
std::shared_ptr<torch::jit::Graph> graph,
|
|
const StaticModuleOptions& opts) {
|
|
PrepareGraphForStaticModule(graph, opts);
|
|
return std::make_pair(graph, c10::nullopt);
|
|
}
|
|
|
|
bool containTensorsOnly(at::ArrayRef<Value*> values) {
|
|
// return true only if all outputs are tensors
|
|
return std::all_of(values.begin(), values.end(), [](const Value* value) {
|
|
return value->type()->castRaw<TensorType>() != nullptr;
|
|
});
|
|
}
|
|
|
|
} // namespace
|
|
|
|
StaticModule::StaticModule(
|
|
std::shared_ptr<torch::jit::Graph> g,
|
|
const StaticModuleOptions& opts)
|
|
: StaticModule(PrepareForStaticModule(g->copy(), opts), opts) {}
|
|
|
|
StaticModule::StaticModule(
|
|
const torch::jit::Module& m,
|
|
bool is_frozen,
|
|
const StaticModuleOptions& opts)
|
|
: StaticModule(PrepareForStaticModule(m, is_frozen, opts), opts) {}
|
|
|
|
StaticModule::StaticModule(
|
|
std::pair<std::shared_ptr<torch::jit::Graph>, c10::optional<Module>>
|
|
graph_and_module,
|
|
const StaticModuleOptions& opts)
|
|
: opts_(opts),
|
|
graph_(std::move(graph_and_module.first)),
|
|
module_(std::move(graph_and_module.second)) {
|
|
// check opt flags
|
|
if (opts.manage_output_tensors) {
|
|
TORCH_CHECK(
|
|
opts_.enable_out_variant,
|
|
"When manage_output_tensors is true, enable_out_variant must be set to true");
|
|
}
|
|
if (opts_.optimize_memory) {
|
|
TORCH_CHECK(
|
|
opts_.enable_out_variant,
|
|
"When optimize_memory is true, enable_out_variant must be set to true");
|
|
}
|
|
|
|
// handle schema
|
|
if (module_.has_value()) {
|
|
Method method = module_->get_method("forward");
|
|
if (RemoveSelfFromGraphInput(graph_)) {
|
|
schema_ = RemoveSelfFromSchema(method.function().getSchema());
|
|
module_ = c10::nullopt;
|
|
} else {
|
|
schema_ = method.function().getSchema();
|
|
}
|
|
}
|
|
|
|
// map Value* to its SSA definition IR
|
|
FastMap<Value*, DefInfo> value_to_ssa_def;
|
|
|
|
// N inputs map to the first N entries in storage
|
|
for (const auto i : c10::irange(graph_->inputs().size())) {
|
|
Value* input = graph_->inputs()[i];
|
|
value_to_ssa_def[input] = std::make_pair(INPUT_VALUE, i);
|
|
}
|
|
|
|
// NB: before optimizing the order of execution, ensure that the
|
|
// memory optimization pass (LivenessMap) is
|
|
// aware of the new order!
|
|
|
|
{
|
|
size_t nodes_size = 0, constants_size = 0;
|
|
for (Node* node : graph_->nodes()) {
|
|
++(node->kind() == prim::Constant ? constants_size : nodes_size);
|
|
}
|
|
|
|
constants_.reserve(constants_size);
|
|
functions_.reserve(nodes_size);
|
|
nodes_.reserve(nodes_size);
|
|
}
|
|
|
|
// Create ProcessedFunction instances first to freeze their addresses to pass
|
|
// to ProcessedNode.
|
|
AliasDb alias_db(
|
|
graph_, /*isFrozen=*/false, /*enablePreciseTupleContainerAnalysis=*/true);
|
|
GRAPH_DEBUG("AliasDb: ", alias_db.toString());
|
|
|
|
// Construct constant and function nodes
|
|
for (Node* node : graph_->nodes()) {
|
|
if (node->kind() == prim::Constant) {
|
|
auto* v = node->output();
|
|
TORCH_CHECK(v->type()->kind() != FunctionType::Kind);
|
|
// construct SSA definition for constant nodes
|
|
value_to_ssa_def[v] = std::make_pair(CONSTANT_VALUE, constants_.size());
|
|
constants_.emplace_back(toIValue(v).value());
|
|
continue;
|
|
}
|
|
|
|
// see [Check and correct bad schema alias info at runtime]
|
|
bool check_outputs_for_overlap =
|
|
!alias_db.mayContainAlias(node->inputs(), node->outputs()) &&
|
|
containTensorsOnly(node->outputs());
|
|
// new ProcessedFunction
|
|
functions_.emplace_back(
|
|
node, opts.enable_out_variant, check_outputs_for_overlap);
|
|
}
|
|
|
|
// construct SSA definition for non-constant nodes
|
|
int node_idx = 0;
|
|
FastMap<Node*, bool> node_has_out_variant;
|
|
|
|
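// Layout of the flat values_ table: the graph inputs come first, followed by
// the constants, followed by every non-constant node's outputs. The offsets
// below index into that table.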
const auto inputs_index_offset = 0;
|
|
const auto constants_index_offset = inputs_index_offset + num_inputs();
|
|
const auto values_index_offset = constants_index_offset + constants().size();
|
|
|
|
// Map node_idx to index offset in values_. Can't reserve space
|
|
// because we don't know how many non-constant nodes there are yet.
|
|
std::vector<uint32_t> node_output_idx_map;
|
|
uint32_t node_outputs_seen_so_far = 0;
|
|
for (Node* node : graph_->nodes()) {
|
|
if (node->kind() == prim::Constant) {
|
|
continue;
|
|
}
|
|
// Assign memory for the outputs
|
|
const auto outputs_offset_for_node =
|
|
node_outputs_seen_so_far + values_index_offset;
|
|
TORCH_CHECK(
|
|
outputs_offset_for_node < (1 << 16),
|
|
"outputs offset in values table",
|
|
outputs_offset_for_node,
|
|
" would overflow 2-byte index storage");
|
|
node_output_idx_map.push_back(outputs_offset_for_node);
|
|
node_outputs_seen_so_far += node->outputs().size();
|
|
}
|
|
|
|
for (Node* node : graph_->nodes()) {
|
|
if (node->kind() == prim::Constant) {
|
|
continue;
|
|
}
|
|
ProcessedNodeInputs input_indices(node->inputs().size());
|
|
std::vector<DefInfo> input_ssa_defs;
|
|
for (const auto input_idx : c10::irange(node->inputs().size())) {
|
|
Value* const input = node->inputs()[input_idx];
|
|
int inner_node_idx = 0;
|
|
int out_idx = 0;
|
|
std::tie(inner_node_idx, out_idx) = value_to_ssa_def.at(input);
|
|
unsigned int input_ivalue_idx = 0;
|
|
if (inner_node_idx == StaticModule::INPUT_VALUE) {
|
|
input_ivalue_idx = out_idx + inputs_index_offset;
|
|
} else if (inner_node_idx == StaticModule::CONSTANT_VALUE) {
|
|
input_ivalue_idx = out_idx + constants_index_offset;
|
|
} else {
|
|
DCHECK_GE(inner_node_idx, 0);
|
|
const auto global_value_idx =
|
|
node_output_idx_map[inner_node_idx] + out_idx;
|
|
if (inner_node_idx < node_output_idx_map.size() - 1) {
|
|
DCHECK_LT(global_value_idx, node_output_idx_map[inner_node_idx + 1]);
|
|
} else {
|
|
DCHECK_LT(
|
|
global_value_idx,
|
|
constants_index_offset + node_outputs_seen_so_far);
|
|
}
|
|
input_ivalue_idx = global_value_idx;
|
|
}
|
|
TORCH_CHECK(
|
|
input_ivalue_idx < (1 << 16),
|
|
"input index in values table ",
|
|
input_ivalue_idx,
|
|
" would overflow 2-byte index storage");
|
|
input_indices[input_idx] = input_ivalue_idx;
|
|
}
|
|
|
|
ProcessedFunction* fn = &functions_[node_idx];
|
|
// create a new ProcessedNode
|
|
// see [Check and correct bad schema alias info at runtime]
|
|
bool check_outputs_for_overlap =
|
|
!alias_db.mayContainAlias(node->inputs(), node->outputs()) &&
|
|
containTensorsOnly(node->outputs());
|
|
nodes_.emplace_back(
|
|
node, fn, std::move(input_indices), node_output_idx_map[node_idx]);
|
|
|
|
node_has_out_variant.emplace(node, nodes_.back().has_out_variant());
|
|
for (const auto i : c10::irange(node->outputs().size())) {
|
|
value_to_ssa_def[node->outputs()[i]] = std::make_pair(node_idx, i);
|
|
}
|
|
node_idx++;
|
|
}
|
|
for (auto& pnode : nodes_) {
|
|
if (pnode.num_outputs() == 1 &&
|
|
isOptimizableContainerType(pnode.node(), node_has_out_variant)) {
|
|
node_is_optimizable_container_type_.emplace(pnode.node());
|
|
}
|
|
}
|
|
output_indices_.reserve(graph_->outputs().size());
|
|
for (auto output : graph_->outputs()) {
|
|
int node_idx = 0;
|
|
int out_idx = 0;
|
|
std::tie(node_idx, out_idx) = value_to_ssa_def[output];
|
|
uint32_t output_index = 0;
|
|
if (node_idx == StaticModule::INPUT_VALUE) {
|
|
output_index = out_idx + inputs_index_offset;
|
|
} else if (node_idx == StaticModule::CONSTANT_VALUE) {
|
|
output_index = constants_index_offset + out_idx;
|
|
} else {
|
|
output_index = nodes_[node_idx].output_ivalue_index(out_idx);
|
|
}
|
|
TORCH_CHECK(
|
|
output_index < (1 << 16),
|
|
"output index ",
|
|
output_index,
|
|
" would overflow 2-byte index storage");
|
|
output_indices_.emplace_back(output_index);
|
|
}
|
|
|
|
// Prepare for memory planning
|
|
value_group_ = std::make_unique<ValueGroup>(graph_, alias_db);
|
|
GRAPH_DEBUG(value_group_->toString());
|
|
|
|
if (opts_.optimize_memory) {
|
|
auto lm = GetLivenessMap(graph_, *value_group_, alias_db);
|
|
auto values = GetMemoryPlanningCandidates(graph_, node_has_out_variant);
|
|
value_to_same_storage_values_ =
|
|
GenerateSameStorageValues(lm, *value_group_, values, alias_db);
|
|
}
|
|
|
|
prepareForMemoryPlanner();
|
|
}
|
|
|
|
void StaticModule::prepareForMemoryPlanner() {
|
|
if (!opts_.enable_out_variant) {
|
|
return;
|
|
}
|
|
|
|
// Never manage graph outputs so that we can do std::move(output_ivalue).
|
|
// This does not affect performance if the graph returns a collection object.
|
|
FastSet<const Value*> graph_output_values(
|
|
graph_->outputs().begin(), graph_->outputs().end());
|
|
|
|
// collect register indices of outputs of ops with out variant
|
|
for (ProcessedNode& pnode : nodes_) {
|
|
if (!pnode.has_out_variant()) {
|
|
continue;
|
|
}
|
|
auto outputs = pnode.node()->outputs();
|
|
for (const auto i : c10::irange(outputs.size())) {
|
|
const Value* out_v = outputs[i];
|
|
// Types are stored in the underlying TorchScript IR
|
|
bool is_tensor_type = out_v->type()->castRaw<TensorType>();
|
|
if (opts_.manage_output_tensors && is_tensor_type &&
|
|
graph_output_values.find(out_v) == graph_output_values.end() &&
|
|
value_group_->isOutputAlias(out_v)) {
|
|
managed_output_tensor_values_.insert(out_v);
|
|
continue;
|
|
}
|
|
if (value_group_->isAlwaysAlive(out_v)) {
|
|
continue;
|
|
}
|
|
if (is_tensor_type) {
|
|
managed_tensor_values_.insert(out_v);
|
|
} else if (is_optimizable_container_type(pnode.node())) {
|
|
// We "leak" certain container types because their allocations
|
|
// take a long time
|
|
leaked_values_.insert(out_v);
|
|
}
|
|
}
|
|
}
|
|
|
|
for (const Value* output : graph_->outputs()) {
|
|
managed_tensor_values_.erase(output);
|
|
}
|
|
GRAPH_DEBUG("managed_tensor_values: ", dumpValueSet(managed_tensor_values_));
|
|
GRAPH_DEBUG(
|
|
"managed_output_tensor_values_: ",
|
|
dumpValueSet(managed_output_tensor_values_));
|
|
}
|
|
|
|
// These need to be defined in the .cpp file since they require
|
|
// concrete types for ProcessedFunction/ProcessedNode.
|
|
StaticModule::StaticModule(torch::jit::StaticModule&&) noexcept {}
|
|
StaticModule::~StaticModule() = default;
|
|
|
|
const StaticModuleOptions& StaticModule::opts() const {
|
|
return opts_;
|
|
}
|
|
|
|
size_t StaticModule::num_outputs() const {
|
|
return graph_->outputs().size();
|
|
}
|
|
|
|
size_t StaticModule::num_inputs() const {
|
|
return graph_->inputs().size();
|
|
}
|
|
|
|
StaticRuntime& StaticModule::runtime() {
|
|
if (!cached_runtime_) {
|
|
cached_runtime_ = std::make_unique<StaticRuntime>(*this);
|
|
}
|
|
return *cached_runtime_;
|
|
}
|
|
|
|
const std::vector<ProcessedNode>& StaticModule::nodes() const {
|
|
return nodes_;
|
|
}
|
|
|
|
size_t StaticModule::num_nodes() const {
|
|
return nodes_.size();
|
|
}
|
|
|
|
Node* StaticModule::findNodeWithKindForTesting(const std::string& kind) const {
|
|
for (auto& pnode : nodes()) {
|
|
if (pnode.node()->kind().toQualString() == kind) {
|
|
return pnode.node();
|
|
}
|
|
}
|
|
return nullptr;
|
|
}
|
|
|
|
c10::IValue StaticModule::operator()(
|
|
const std::vector<c10::IValue>& args,
|
|
const KeywordArgs& kwargs) {
|
|
return runtime()(args, kwargs);
|
|
}
|
|
|
|
c10::IValue StaticModule::operator()(
|
|
std::vector<c10::IValue>&& args,
|
|
const KeywordArgs& kwargs) {
|
|
return runtime()(std::move(args), kwargs);
|
|
}
|
|
|
|
StaticRuntime::StaticRuntime(const StaticModule& sm)
|
|
: static_module_(sm),
|
|
manage_output_tensors_enabled_(sm.opts().manage_output_tensors),
|
|
nodes_(sm.nodes()) {
|
|
const auto total_num_node_outputs = std::accumulate(
|
|
nodes_.begin(),
|
|
nodes_.end(),
|
|
0,
|
|
[](uint32_t sum, const ProcessedNode& pnode) {
|
|
return sum + pnode.num_outputs();
|
|
});
|
|
values_.resize(
|
|
sm.num_inputs() + sm.constants().size() + total_num_node_outputs);
|
|
const auto inputs_index_offset = 0;
|
|
const auto constants_index_offset = inputs_index_offset + sm.num_inputs();
|
|
const auto constants_begin_it = values_.begin() + constants_index_offset;
|
|
const auto constants_end_it = constants_begin_it + sm.constants().size();
|
|
std::copy(sm.constants().begin(), sm.constants().end(), constants_begin_it);
|
|
|
|
for (const auto idx : c10::irange(sm.nodes().size())) {
|
|
auto& n = nodes_[idx];
|
|
n.set_values(values_.data());
|
|
}
|
|
|
|
// TODO: can we convert outputs_ to store indices?
|
|
for (auto index : sm.output_indices()) {
|
|
outputs_.emplace_back(&values_[index]);
|
|
}
|
|
}
|
|
|
|
StaticRuntime::~StaticRuntime() = default;
|
|
|
|
void StaticRuntime::set_inputs(
|
|
const std::vector<IValue>& args,
|
|
const KeywordArgs& kwargs) {
|
|
if (!kwargs.empty()) {
|
|
// This is not ideal
|
|
TORCH_CHECK(
|
|
static_module_.schema(),
|
|
"Schema is not available. Consider creating the Static Runtime "
|
|
"with StaticModule(const torch::jit::Module& m) instead.");
|
|
std::vector<c10::IValue> stack;
|
|
stack.reserve(static_module_.num_inputs());
|
|
if (static_module_.first_input_is_self()) {
|
|
stack.emplace_back(static_module_.module()._ivalue());
|
|
}
|
|
stack.insert(stack.end(), args.begin(), args.end());
|
|
|
|
static_module_.schema()->checkAndNormalizeInputs(stack, kwargs);
|
|
DCHECK_EQ(static_module_.num_inputs(), stack.size());
|
|
for (const auto i : c10::irange(stack.size())) {
|
|
Input(i) = std::move(stack[i]);
|
|
}
|
|
} else {
|
|
if (static_module_.first_input_is_self()) {
|
|
Input(0) = static_module_.module()._ivalue();
|
|
DCHECK_EQ(static_module_.num_inputs(), args.size() + 1);
|
|
for (const auto i : c10::irange(args.size())) {
|
|
Input(i + 1) = args[i];
|
|
}
|
|
} else {
|
|
DCHECK_EQ(static_module_.num_inputs(), args.size());
|
|
for (const auto i : c10::irange(args.size())) {
|
|
Input(i) = args[i];
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
void StaticRuntime::set_inputs(
|
|
std::vector<IValue>&& args,
|
|
const KeywordArgs& kwargs) {
|
|
if (!kwargs.empty()) {
|
|
// This is not ideal
|
|
TORCH_CHECK(
|
|
static_module_.schema(),
|
|
"Schema is not available. Consider creating the Static Runtime "
|
|
"with StaticModule(const torch::jit::Module& m) instead.");
|
|
std::vector<c10::IValue> stack;
|
|
stack.reserve(static_module_.num_inputs());
|
|
if (static_module_.first_input_is_self()) {
|
|
stack.emplace_back(static_module_.module()._ivalue());
|
|
}
|
|
stack.insert(
|
|
stack.end(),
|
|
std::make_move_iterator(args.begin()),
|
|
std::make_move_iterator(args.end()));
|
|
|
|
static_module_.schema()->checkAndNormalizeInputs(stack, kwargs);
|
|
DCHECK_EQ(static_module_.num_inputs(), stack.size());
|
|
for (const auto i : c10::irange(stack.size())) {
|
|
Input(i) = std::move(stack[i]);
|
|
}
|
|
} else {
|
|
if (static_module_.first_input_is_self()) {
|
|
Input(0) = static_module_.module()._ivalue();
|
|
DCHECK_EQ(static_module_.num_inputs(), args.size() + 1);
|
|
for (const auto i : c10::irange(args.size())) {
|
|
Input(i + 1) = std::move(args[i]);
|
|
}
|
|
} else {
|
|
DCHECK_EQ(static_module_.num_inputs(), args.size());
|
|
for (const auto i : c10::irange(args.size())) {
|
|
Input(i) = std::move(args[i]);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
void StaticRuntime::create_memory_planner() {
|
|
if (!planner_) {
|
|
planner_ = std::make_unique<MemoryPlanner>(
|
|
this,
|
|
static_module_.values_share_same_storage(),
|
|
static_module_.value_group(),
|
|
static_module_.managed_tensor_values(),
|
|
static_module_.managed_output_tensor_values(),
|
|
static_module_.leaked_values(),
|
|
static_module_.opts().enable_out_variant,
|
|
manage_output_tensors_enabled_);
|
|
}
|
|
}
|
|
|
|
c10::IValue StaticRuntime::move_outputs_to_tuple(uint32_t num_outputs) {
|
|
#ifndef NDEBUG
|
|
for (const auto i : c10::irange(num_outputs)) {
|
|
// The exact output tensor should never be managed.
|
|
DCHECK(!isManagedOutputTensor(*outputs_[i]));
|
|
}
|
|
#endif
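// For small output counts, create the tuple directly from the individual
// IValues; the default case packs the outputs into a std::vector first.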
|
|
switch (num_outputs) {
|
|
case 1:
|
|
return c10::ivalue::Tuple::create(std::move(*outputs_[0]));
|
|
case 2:
|
|
return c10::ivalue::Tuple::create(
|
|
std::move(*outputs_[0]), std::move(*outputs_[1]));
|
|
case 3:
|
|
return c10::ivalue::Tuple::create(
|
|
std::move(*outputs_[0]),
|
|
std::move(*outputs_[1]),
|
|
std::move(*outputs_[2]));
|
|
default: {
|
|
std::vector<c10::IValue> outputs;
|
|
outputs.reserve(num_outputs);
|
|
for (const auto i : c10::irange(num_outputs)) {
|
|
// use move here. Otherwise, clean up outputs_[i] explicitly
|
|
outputs.emplace_back(std::move(*outputs_[i]));
|
|
}
|
|
return c10::ivalue::Tuple::create(std::move(outputs));
|
|
}
|
|
}
|
|
}
|
|
|
|
/// [Check and correct bad schema alias info at runtime]
|
|
/// Static runtime relies on the operator schema's alias info to be correct for
|
|
/// memory planning. Because it's hard to enforce the alias info to be correct,
|
|
/// we need to do runtime detection for accidental aliases that do not comply
|
|
/// with the schema. Only aliases of managed tensors are problematic. To avoid
|
|
/// runtime crashes, we can add runtime detection and force the op to comply
|
|
/// with its schema by cloning the alias. Because all managed tensors' data_ptrs
|
|
/// are part of the internal buffer that the MemoryPlanner allocates, we can
|
|
/// check aliases by checking the memory overlap with this internal buffer. But
|
|
/// a tensor's storage can be resized during inference, so we need another way to
|
|
/// handle the resized case.
|
|
///
|
|
/// There are two ways for incorrect schema to break memory planning. Let's look
|
|
/// at two examples:
|
|
///
|
|
/// Example 1:
|
|
/// @code
|
|
/// def forward(x):
|
|
/// a = x + x
|
|
/// b = bad_op(a) # b ends up aliasing a incorrectly
|
|
/// return (b)
|
|
/// @endcode
|
|
/// bad_op: its schema says it returns a new Tensor, but it actually returns an
|
|
/// alias. In this case, the memory planner would recognize `a` as a managed
|
|
/// tensor and clean up its memory before returning `b`. But `b` is actually an
|
|
/// alias of `a`, when `a`'s data_ptr get reset, `b`'s data_ptr gets reset too.
|
|
///
|
|
/// Example 2:
|
|
/// @code
|
|
/// def forward(x):
|
|
/// a = x + x
|
|
/// a2 = bad_op(a) # a2 ends up aliasing a incorrectly
|
|
/// b = a + a
|
|
/// c = b * b # c shares storage with a
|
|
/// d = c + 2 # d shares storage with b
|
|
/// e = a2 * a2
|
|
/// return (d, e)
|
|
/// @endcode
|
|
/// With the memory reuse algorithm, `c` could end up sharing storage with `a`,
|
|
/// but because of bad_op, `a2` now aliases `a`. `c` overwrites `a` and
|
|
/// therefore `a2`, leading to the wrong results. We solve this problem with two
|
|
/// steps. Note this doesn't happen with the current memory reuse algorithm
|
|
/// because of the way it's implemented. Things could change with a different
|
|
/// implementation.
|
|
///
|
|
/// Step 1, annotate the ProcessedNodes with a flag `check_memory_overlap_` set
|
|
/// to true if its outputs do not alias its inputs as indicated by the AliasDb
|
|
/// and all of its outputs are Tensors. Then at runtime, we check that the
|
|
/// nodes' output tensors do not overlap with the internal buffer that the
|
|
/// MemoryPlanner allocates. For latency concerns, we only run this check for
|
|
/// fallback ops. The schemas of native ops and out variants are vetted and
|
|
/// enforced with static runtime unit tests. For the first iteration, we do a
|
|
/// full memory overlap check with
|
|
/// ProcessedNode::verify_and_correct_memory_overlap() because the internal
|
|
/// buffer doesn't exist yet.
|
|
///
|
|
/// Step 2, if a managed tensor gets resized during inference, it gets a new
|
|
/// data_ptr which is not from the buffer. We can tackle this corner case by
|
|
/// delaying the deallocation of the managed tensors to after the outputs are no
|
|
/// longer used (essentially merging the internal/output buffers into one).
|
|
/// Before the merging is implemented, we add another flag `overlap_detected_`
|
|
/// to flag any node with overlap detected in Step 1 and do a full memory
|
|
/// overlap check if the fast check (by checking memory overlap with internal
|
|
/// buffer) fails. There is still a corner case that fails with the added flag.
|
|
/// If a resize is triggered at the same time as the op creating an alias,
/// the current checks would fail to detect the alias.
|
|
///
|
|
/// There is another case of failure that step 2 can prevent. With
|
|
/// StaticModule::opts().cleanup_activations = false, the returned Static
|
|
/// Runtime instance in the instance pool can be re-entered while an unintended
|
|
/// output tensor's alias is still being used by the client (in the
|
|
/// multi-threaded setting). This can only be prevented by delaying the
|
|
/// deallocation and returning the Static Runtime instance after the client is
|
|
/// done with the outputs.
|
|
|
|
void StaticRuntime::verify_and_correct_memory_overlap(ProcessedNode& n) {
|
|
// The slow check can be removed once the internal/output buffers are merged
|
|
if (C10_UNLIKELY(n.check_outputs_for_memory_overlap())) {
|
|
if (C10_UNLIKELY(!planner_ && static_module_.opts().cleanup_activations)) {
|
|
// slow check, for first iter only with cleanup_activations = true
|
|
n.verify_and_correct_memory_overlap();
|
|
} else if (planner_) {
|
|
bool overlap_detected_with_fast_check = false;
|
|
for (size_t i = 0; i < n.outputs().size(); i++) {
|
|
at::Tensor& t = n.Output(i).toTensor();
|
|
if (planner_->overlapWithInternalBuffer(t.data_ptr())) {
|
|
DLOG(INFO) << "Detected alias for node: " << PrintNode(n.node());
|
|
n.Output(i) = at::native::clone(t, c10::nullopt);
|
|
// set flag if overlap detected
|
|
overlap_detected_with_fast_check = true;
|
|
n.set_outputs_memory_overlap_detected();
|
|
}
|
|
}
|
|
if (n.outputs_memory_overlap_detected() &&
|
|
!overlap_detected_with_fast_check) {
|
|
// slow check. Only run when the fast check fails.
|
|
n.verify_and_correct_memory_overlap();
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
template <typename IValueList>
|
|
c10::IValue StaticRuntime::run_impl(
|
|
IValueList&& args,
|
|
const KeywordArgs& kwargs) {
|
|
// We assume inference workloads, so we do not need
|
|
// autograd. Enabling this is a significant win on dispatcher
|
|
// overhead because it saves a round of dispatch for at least some
|
|
// functions, such as resize_ and resize_as_.
|
|
c10::InferenceMode mode;
|
|
|
|
if (planner_) {
|
|
DCHECK(!manage_output_tensors_enabled_ || checkOutputTensorMemoryLeaks());
|
|
planner_->allocate();
|
|
}
|
|
|
|
set_inputs(std::forward<IValueList>(args), kwargs);
|
|
|
|
// NB: before optimizing the order of execution, ensure that the
|
|
// memory optimization pass (LivenessMap) is
|
|
// aware of the new order!
|
|
for (auto& n : nodes_) {
|
|
// LOG(INFO) << "Running node: " << PrintNode(n.node());
|
|
n.run();
|
|
// Check for incorrect schema alias info.
|
|
verify_and_correct_memory_overlap(n);
|
|
}
|
|
|
|
if (static_module_.opts().cleanup_activations) {
|
|
// MemoryPlanner is created after the first invocation of `run()`. This is
|
|
// done intentionally because MemoryPlanner uses `Tensor` sizes of the
|
|
// previous `run()` for memory planning of subsequent runs
|
|
create_memory_planner();
|
|
planner_->deallocate();
|
|
// clean up owning refs of input tensors
|
|
clean_up_input_ivalues();
|
|
}
|
|
|
|
// no need to keep references of outputs in static runtime anymore
|
|
if (static_module_.num_outputs() > 1) {
|
|
return move_outputs_to_tuple(static_module_.num_outputs());
|
|
}
|
|
#ifndef NDEBUG
|
|
check_for_memory_leak(false);
|
|
#endif
|
|
// The exact output tensor should never be managed.
|
|
DCHECK(!isManagedOutputTensor(*outputs_[0]));
|
|
// use move here. Otherwise, clean up outputs_[0] explicitly
|
|
return std::move(*outputs_[0]);
|
|
}
|
|
|
|
template <typename IValueList>
|
|
c10::IValue StaticRuntime::run_impl_record_functions(
|
|
IValueList&& args,
|
|
const KeywordArgs& kwargs) {
|
|
bool pre_sampled = false;
|
|
if (C10_UNLIKELY(at::shouldRunRecordFunction(&pre_sampled))) {
|
|
at::RecordFunction guard(
|
|
at::RecordScope::TORCHSCRIPT_FUNCTION, pre_sampled);
|
|
if (guard.isActive()) {
|
|
if (guard.needsInputs()) {
|
|
guard.before("forward", &args);
|
|
} else {
|
|
guard.before("forward");
|
|
}
|
|
}
|
|
return run_impl(std::forward<IValueList>(args), kwargs);
|
|
}
|
|
return run_impl(std::forward<IValueList>(args), kwargs);
|
|
}
|
|
|
|
c10::IValue StaticRuntime::operator()(
|
|
const std::vector<c10::IValue>& args,
|
|
const KeywordArgs& kwargs) {
|
|
#ifdef PYTORCH_DISABLE_NET_PROFILING
|
|
return run_impl(args, kwargs);
|
|
#else
|
|
return run_impl_record_functions(args, kwargs);
|
|
#endif
|
|
}
|
|
|
|
c10::IValue StaticRuntime::operator()(
|
|
std::vector<c10::IValue>&& args,
|
|
const KeywordArgs& kwargs) {
|
|
#ifdef PYTORCH_DISABLE_NET_PROFILING
|
|
return run_impl(std::move(args), kwargs);
|
|
#else
|
|
return run_impl_record_functions(std::move(args), kwargs);
|
|
#endif
|
|
}
|
|
|
|
namespace {
|
|
|
|
std::string generate_latency_json(const std::string& label, double millis) {
|
|
#ifdef FBCODE_CAFFE2
|
|
folly::dynamic json = folly::dynamic::object();
|
|
json["type"] = label;
|
|
json["metric"] = "latency";
|
|
json["unit"] = "ms";
|
|
json["value"] = millis;
|
|
return "PyTorchObserver " + folly::toJson(json);
|
|
#else
|
|
return "";
|
|
#endif
|
|
}
|
|
|
|
} // namespace
|
|
|
|
void StaticRuntime::benchmark(
|
|
const std::vector<std::vector<c10::IValue>>& args_list,
|
|
const std::vector<KeywordArgs>& kwargs_list,
|
|
const int warmup_runs,
|
|
const int main_runs,
|
|
bool print_per_node_time,
|
|
bool generate_ai_pep_output) {
|
|
TORCH_CHECK(
|
|
kwargs_list.size() == 0 || args_list.size() == kwargs_list.size());
|
|
std::cout << "Input size: " << args_list.size() << std::endl;
|
|
if (args_list.size() == 0) {
|
|
return;
|
|
}
|
|
float time_per_iter =
|
|
benchmark_model(args_list, kwargs_list, warmup_runs, main_runs);
|
|
std::cout << "Static runtime ms per iter: " << time_per_iter
|
|
<< ". Iters per second: " << 1000.0 / time_per_iter << std::endl;
|
|
|
|
IndividualMetrics results =
|
|
benchmark_individual_ops(args_list, kwargs_list, warmup_runs, main_runs);
|
|
|
|
if (print_per_node_time) {
|
|
for (const auto i : c10::irange(nodes_.size())) {
|
|
const Node* node = nodes_[i].node();
|
|
std::cout << "Node #" << i << ": " << results.time_per_node[i]
|
|
<< " ms/iter, ";
|
|
node->print(std::cout, 0, nullptr, false);
|
|
}
|
|
}
|
|
|
|
std::vector<std::pair<std::string, double>> time_per_node_type_vec{
|
|
results.time_per_node_type.begin(), results.time_per_node_type.end()};
|
|
std::sort(
|
|
time_per_node_type_vec.begin(),
|
|
time_per_node_type_vec.end(),
|
|
[](auto& left, auto& right) { return left.second > right.second; });
|
|
|
|
std::cout << "Time per node type:" << std::endl;
|
|
for (const auto& p : time_per_node_type_vec) {
|
|
const std::string& kind = p.first;
|
|
const double ms = p.second;
|
|
std::cout << std::setw(15) << ms << " ms. " << std::setw(10)
|
|
<< results.percent_per_node_type[kind] << "%. " << kind << " ("
|
|
<< results.instances_per_node_type[kind] << " nodes";
|
|
if (results.out_nodes.count(kind)) {
|
|
std::cout << ", out variant)" << std::endl;
|
|
} else if (results.native_nodes.count(kind)) {
|
|
std::cout << ", native)" << std::endl;
|
|
} else {
|
|
std::cout << ")" << std::endl;
|
|
}
|
|
|
|
if (generate_ai_pep_output) {
|
|
LOG(INFO) << generate_latency_json(kind, ms);
|
|
}
|
|
}
|
|
if (generate_ai_pep_output) {
|
|
LOG(INFO) << generate_latency_json(
|
|
"static_runtime_first_iter", results.first_iter_time);
|
|
}
|
|
std::cout << std::setw(15) << results.total_time << " ms. in Total"
|
|
<< std::endl;
|
|
std::cout << "StaticRuntime setup time: " << results.setup_time << " ms"
|
|
<< std::endl;
|
|
std::cout << "Memory allocation time: " << results.memory_alloc_time
|
|
<< " ms\n";
|
|
std::cout << "Memory deallocation time: " << results.memory_dealloc_time
|
|
<< " ms" << std::endl;
|
|
std::cout << "Outputs deallocation time: " << results.output_dealloc_time
|
|
<< " ms" << std::endl;
|
|
std::cout << "First iter time: " << results.first_iter_time << " ms"
|
|
<< std::endl;
|
|
std::cout << "Number of operators: " << nodes_.size() << std::endl;
|
|
|
|
if (planner_) {
|
|
std::cout << "Total number of managed tensors: "
|
|
<< planner_->total_num_managed_tensors() << std::endl;
|
|
std::cout << "Total number of managed output tensors: "
|
|
<< planner_->total_num_managed_output_tensors() << std::endl;
|
|
std::cout << "Total number of unmanaged values: "
|
|
<< planner_->total_num_unmanaged() << std::endl;
|
|
std::cout << "Number of unmanaged values requiring cleanup: "
|
|
<< planner_->num_unmanaged_non_scalars() << std::endl;
|
|
std::cout << "Number of unmanaged values not requiring cleanup: "
|
|
<< planner_->num_unmanaged_scalars() << std::endl;
|
|
std::cout << "Total memory managed: " << planner_->total_managed()
|
|
<< " bytes" << std::endl;
|
|
if (static_module_.opts().optimize_memory) {
|
|
std::cout << "Total number of reused tensors: "
|
|
<< planner_->total_reused_tensors() << std::endl;
|
|
}
|
|
std::cout << "Total number of 'out' variant nodes/total number of nodes: "
|
|
<< results.out_nodes_count << "/" << results.total_nodes_count
|
|
<< " ("
|
|
<< 100.0 * (results.out_nodes_count) /
|
|
static_cast<float>(results.total_nodes_count)
|
|
<< "%)" << std::endl;
|
|
}
|
|
check_for_memory_leak();
|
|
|
|
#ifndef NDEBUG
|
|
KeywordArgs empty_kwargs;
|
|
display_nodes(
|
|
args_list[0], kwargs_list.size() > 0 ? kwargs_list[0] : empty_kwargs);
|
|
#endif
|
|
}
|
|
|
|
float StaticRuntime::benchmark_model(
|
|
const std::vector<std::vector<c10::IValue>>& args_list,
|
|
const std::vector<KeywordArgs>& kwargs_list,
|
|
const int warmup_runs,
|
|
const int main_runs) {
|
|
TORCH_CHECK(warmup_runs >= 0 && main_runs >= 1);
|
|
TORCH_CHECK(
|
|
kwargs_list.size() == 0 || args_list.size() == kwargs_list.size());
|
|
|
|
const bool is_kwargs_empty = kwargs_list.size() == 0;
|
|
const KeywordArgs empty_kwargs;
|
|
for (const auto i : c10::irange(warmup_runs)) {
|
|
(void)i; // Suppress unused variable warning
|
|
for (const auto j : c10::irange(args_list.size())) {
|
|
operator()(args_list[j], is_kwargs_empty ? empty_kwargs : kwargs_list[j]);
|
|
if (manage_output_tensors_enabled_) {
|
|
deallocateOutputTensors();
|
|
}
|
|
}
|
|
}
|
|
caffe2::Timer timer;
|
|
for (const auto i : c10::irange(main_runs)) {
|
|
(void)i; // Suppress unused variable warning
|
|
for (const auto j : c10::irange(args_list.size())) {
|
|
operator()(args_list[j], is_kwargs_empty ? empty_kwargs : kwargs_list[j]);
|
|
if (manage_output_tensors_enabled_) {
|
|
deallocateOutputTensors();
|
|
}
|
|
}
|
|
}
|
|
float millis = timer.MilliSeconds();
|
|
return millis / (static_cast<float>(main_runs) * args_list.size());
|
|
}
|
|
|
|
bool display_ivalue(const IValue& iv) {
|
|
if (iv.isTensor()) {
|
|
std::cout << "Tensor " << iv.toTensor().toString() << " {";
|
|
for (const auto i : c10::irange(iv.toTensor().sizes().size())) {
|
|
std::cout << iv.toTensor().sizes()[i];
|
|
if (iv.toTensor().sizes().size() > i + 1) {
|
|
std::cout << ", ";
|
|
}
|
|
}
|
|
std::cout << "}\n";
|
|
return true;
|
|
} else if (iv.isTensorList()) {
|
|
std::cout << "TensorList {" << iv.toTensorList().size() << "}\n";
|
|
return true;
|
|
} else if (iv.isGenericDict()) {
|
|
std::cout << "Dict {" << iv.toGenericDict().size() << "}\n";
|
|
return true;
|
|
} else if (iv.isTuple()) {
|
|
std::cout << "Tuple {" << iv.toTupleRef().elements().size() << "}\n";
|
|
return true;
|
|
} else if (iv.isInt()) {
|
|
std::cout << "int {" << iv.toInt() << "}\n";
|
|
return true;
|
|
} else if (iv.isBool()) {
|
|
std::cout << "bool {" << iv.toBool() << "}\n";
|
|
return true;
|
|
} else if (iv.isDouble()) {
|
|
std::cout << "double {" << iv.toDouble() << "}\n";
|
|
return true;
|
|
}
|
|
return false;
|
|
}
|
|
|
|
void display_pnode_info(const ProcessedNode& pnode) {
|
|
pnode.node()->print(std::cout, 0, nullptr, false);
|
|
for (const auto i : c10::irange(pnode.num_inputs())) {
|
|
std::cout << "\ti" << i << ": ";
|
|
if (!display_ivalue(pnode.Input(i))) {
|
|
std::cout << *(pnode.node()->inputs()[i]->type()) << '\n';
|
|
}
|
|
}
|
|
const auto outputs = pnode.outputs();
|
|
for (const auto i : c10::irange(outputs.size())) {
|
|
std::cout << "\to" << i << ": ";
|
|
if (!display_ivalue(outputs[i])) {
|
|
std::cout << *(pnode.node()->outputs()[i]->type()) << '\n';
|
|
}
|
|
}
|
|
}
|
|
|
|
void StaticRuntime::display_nodes(
|
|
const std::vector<c10::IValue>& args,
|
|
const KeywordArgs& kwargs) {
|
|
c10::InferenceMode mode;
|
|
if (planner_) {
|
|
planner_->allocate();
|
|
}
|
|
set_inputs(args, kwargs);
|
|
|
|
for (auto& node : nodes_) {
|
|
node.run();
|
|
display_pnode_info(node);
|
|
}
|
|
|
|
if (static_module_.opts().cleanup_activations) {
|
|
// MemoryPlanner is created after the first invocation of `run()`. This is
|
|
// done intentionally because MemoryPlanner uses `Tensor` sizes of the
|
|
// previous `run()` for memory planning of subsequent runs
|
|
create_memory_planner();
|
|
planner_->deallocate();
|
|
// clean up owning refs of input tensors
|
|
clean_up_input_ivalues();
|
|
}
|
|
}
|
|
|
|
StaticRuntime::IndividualMetrics StaticRuntime::benchmark_individual_ops(
|
|
const std::vector<std::vector<c10::IValue>>& args_list,
|
|
const std::vector<KeywordArgs>& kwargs_list,
|
|
const int warmup_runs,
|
|
const int main_runs) {
|
|
TORCH_CHECK(
|
|
kwargs_list.size() == 0 || args_list.size() == kwargs_list.size());
|
|
TORCH_CHECK(warmup_runs >= 1 && main_runs >= 1);
|
|
if (args_list.size() == 0) {
|
|
return {};
|
|
}
|
|
|
|
const bool is_kwargs_empty = kwargs_list.size() == 0;
|
|
const KeywordArgs empty_kwargs;
|
|
bool manage_output_tensors = static_module_.opts().manage_output_tensors;
|
|
// See comment on above use of InferenceMode for
|
|
// explanation.
|
|
c10::InferenceMode mode;
|
|
|
|
IndividualMetrics results;
|
|
results.time_per_node.resize(nodes_.size(), 0);
|
|
|
|
// setup time
|
|
caffe2::Timer timer;
|
|
|
|
set_inputs(args_list[0], is_kwargs_empty ? empty_kwargs : kwargs_list[0]);
|
|
|
|
results.setup_time = timer.MilliSeconds();
|
|
|
|
// The first iteration profiles each node's output Tensors' sizes and
|
|
// initializes the memory planner with the profile information. Following
|
|
// iterations just use the already established memory planning.
|
|
timer.Start();
|
|
operator()(args_list[0], is_kwargs_empty ? empty_kwargs : kwargs_list[0]);
|
|
if (manage_output_tensors) {
|
|
deallocateOutputTensors();
|
|
}
|
|
results.first_iter_time = timer.MilliSeconds();
|
|
|
|
// warmup runs
|
|
for (const auto i : c10::irange(warmup_runs - 1)) {
|
|
(void)i; // Suppress unused variable warning
|
|
for (const auto j : c10::irange(args_list.size())) {
|
|
operator()(args_list[j], is_kwargs_empty ? empty_kwargs : kwargs_list[j]);
|
|
if (manage_output_tensors) {
|
|
deallocateOutputTensors();
|
|
}
|
|
}
|
|
}
|
|
|
|
// main runs
|
|
for (const auto i : c10::irange(main_runs)) {
|
|
(void)i; // Suppress unused variable warning
|
|
|
|
for (const auto j : c10::irange(args_list.size())) {
|
|
set_inputs(args_list[j], is_kwargs_empty ? empty_kwargs : kwargs_list[j]);
|
|
|
|
timer.Start();
|
|
if (planner_) {
|
|
planner_->allocate();
|
|
}
|
|
float millis = timer.MilliSeconds();
|
|
results.memory_alloc_time += millis;
|
|
|
|
for (const auto k : c10::irange(nodes_.size())) {
|
|
timer.Start();
|
|
nodes_[k].run();
|
|
millis = timer.MilliSeconds();
|
|
results.time_per_node[k] += millis;
|
|
}
|
|
timer.Start();
|
|
if (static_module_.opts().cleanup_activations) {
|
|
create_memory_planner();
|
|
planner_->deallocate();
|
|
// clean up owning refs of input tensors
|
|
clean_up_input_ivalues();
|
|
}
|
|
if (manage_output_tensors) {
|
|
deallocateOutputTensors();
|
|
}
|
|
millis = timer.MilliSeconds();
|
|
results.memory_dealloc_time += millis;
|
|
|
|
timer.Start();
|
|
// no need to keep references of outputs in static runtime anymore
|
|
c10::IValue output;
|
|
if (static_module_.num_outputs() > 1) {
|
|
output = move_outputs_to_tuple(static_module_.num_outputs());
|
|
}
|
|
|
|
#ifndef NDEBUG
|
|
check_for_memory_leak(false);
|
|
#endif
|
|
|
|
// use move here. Otherwise, clean up outputs_[0] explicitly
|
|
output = std::move(*outputs_[0]);
|
|
// release outputs explicitly to measure the time it takes
|
|
output = IValue();
|
|
millis = timer.MilliSeconds();
|
|
results.output_dealloc_time += millis;
|
|
}
|
|
}
|
|
|
|
// post processing
|
|
const float num_total_iters =
|
|
(static_cast<float>(main_runs) * args_list.size());
|
|
for (const auto i : c10::irange(nodes_.size())) {
|
|
const Node* node = nodes_[i].node();
|
|
std::string kind = std::string(node->kind().toQualString());
|
|
results.time_per_node[i] /= num_total_iters;
|
|
results.time_per_node_type[kind] += results.time_per_node[i];
|
|
results.instances_per_node_type[kind]++;
|
|
if (nodes_[i].has_out_variant()) {
|
|
results.out_nodes.insert(kind);
|
|
results.out_nodes_count++;
|
|
} else if (nodes_[i].has_native()) {
|
|
results.native_nodes.insert(kind);
|
|
}
|
|
results.total_time += results.time_per_node[i];
|
|
}
|
|
results.total_nodes_count = nodes_.size();
|
|
results.memory_alloc_time /= num_total_iters;
|
|
results.memory_dealloc_time /= num_total_iters;
|
|
results.output_dealloc_time /= num_total_iters;
|
|
for (const auto& p : results.time_per_node_type) {
|
|
const std::string& kind = p.first;
|
|
results.percent_per_node_type[kind] = p.second / results.total_time * 100;
|
|
}
|
|
return results;
|
|
}
|
|
|
|
const std::vector<ProcessedNode>& StaticRuntime::nodes() const {
|
|
return nodes_;
|
|
}
|
|
|
|
std::vector<ProcessedNode>& StaticRuntime::nodes() {
|
|
return nodes_;
|
|
}
|
|
|
|
void StaticRuntime::check_for_memory_leak(bool output_returned) {
|
|
if (!static_module_.opts().cleanup_activations) {
|
|
return;
|
|
}
|
|
|
|
// check for inputs
|
|
for (const auto i : c10::irange(static_module_.num_inputs())) {
|
|
TORCH_CHECK(values_[i].isNone(), "Input ", i, " was not cleaned up");
|
|
}
|
|
FastSet<const IValue*> output_ivalues(outputs_.begin(), outputs_.end());
|
|
for (const auto n : c10::irange(nodes_.size())) {
|
|
auto& pnode = nodes_[n];
|
|
for (const auto i : c10::irange(pnode.num_outputs())) {
|
|
const IValue* ival = &pnode.Output(i);
|
|
const Value* val = pnode.node()->output(i);
|
|
if (planner_ && isManagedOutputTensorValue(val)) {
|
|
// `ival` contains a managed output tensor that the runtime doesn't
|
|
// reclaim at the end of an iteration, but the client does so
|
|
// by explicitly calling `StaticRuntime::deallocateOutputTensors`.
|
|
continue;
|
|
}
|
|
const std::string error_msg = "Output " + c10::to_string(i) + ", %" +
|
|
val->debugName() + " of node " + c10::to_string(n) +
|
|
" was not cleaned up";
|
|
if (output_ivalues.count(ival) == 0) {
|
|
// check for intermediates
|
|
if (!ival->isNone()) {
|
|
TORCH_CHECK(
|
|
ival->isTensor() ||
|
|
static_module_.is_optimizable_container_type(pnode.node()) ||
|
|
doesNotHeapAllocateWhenStoredInIValue(*val->type()),
|
|
error_msg);
|
|
if (ival->isTensor()) {
|
|
const auto& t = ival->toTensor();
|
|
if (t.defined()) {
|
|
auto* storage_impl = t.storage().unsafeGetStorageImpl();
|
|
TORCH_CHECK(
|
|
storage_impl->data() == nullptr ||
|
|
(planner_ &&
|
|
planner_->isManagedStorageImpl(storage_impl)),
|
|
error_msg);
|
|
}
|
|
}
|
|
}
|
|
} else {
|
|
// check for outputs
|
|
if (output_returned) {
|
|
TORCH_CHECK(ival->isNone(), error_msg);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
VLOG(1) << "Finished checking for memory leak";
|
|
}
|
|
|
|
void StaticRuntime::deallocateOutputTensors() {
|
|
if (!static_module_.opts().manage_output_tensors) {
|
|
TORCH_CHECK(
|
|
!planner_ || planner_->numOutputBufferBytes() == 0,
|
|
"manage_output_tensors is disabled, but output tensor buffer is not empty.");
|
|
return;
|
|
}
|
|
if (planner_) {
|
|
planner_->deallocateOutputTensors();
|
|
DCHECK(checkOutputTensorMemoryLeaks());
|
|
}
|
|
}
|
|
|
|
bool StaticRuntime::checkOutputTensorMemoryLeaks() {
|
|
if (!static_module_.opts().manage_output_tensors || !planner_) {
|
|
return true;
|
|
}
|
|
for (const auto n : c10::irange(nodes_.size())) {
|
|
auto& pnode = nodes_[n];
|
|
for (const auto i : c10::irange(pnode.num_outputs())) {
|
|
const IValue* ival = &pnode.Output(i);
|
|
const Value* val = pnode.node()->output(i);
|
|
if (!isManagedOutputTensorValue(val)) {
|
|
continue;
|
|
}
|
|
const auto& t = ival->toTensor();
|
|
if (t.defined()) {
|
|
auto* storage_impl = t.storage().unsafeGetStorageImpl();
|
|
const std::string error_msg = "Output " + c10::to_string(i) + ", %" +
|
|
val->debugName() + " of node " + c10::to_string(n) +
|
|
" was not cleaned up";
|
|
TORCH_CHECK(storage_impl->data() == nullptr, error_msg);
|
|
}
|
|
}
|
|
}
|
|
VLOG(1) << "Finished checking for memory leak from output tensors";
|
|
return true;
|
|
}
|
|
|
|
bool StaticRuntime::isManagedOutputTensor(const IValue& ivalue) const {
|
|
return planner_ && planner_->isManagedOutputTensor(ivalue);
|
|
}
|
|
|
|
bool StaticRuntime::isManagedOutputTensorValue(const Value* value) const {
|
|
// It's possible that manage_output_tensors_ was disabled after initializing
|
|
// managed_output_tensor_values, so we have to check that flag here.
|
|
if (!planner_ || !manage_output_tensors_enabled_) {
|
|
return false;
|
|
}
|
|
const auto& managed_outputs = static_module_.managed_output_tensor_values();
|
|
return managed_outputs.find(value) != managed_outputs.end();
|
|
}
|
|
|
|
void StaticRuntime::disableManageOutputTensors() {
|
|
if (!manage_output_tensors_enabled_) {
|
|
return;
|
|
}
|
|
manage_output_tensors_enabled_ = false;
|
|
if (!planner_) {
|
|
return;
|
|
}
|
|
// Reset all IValues and destruct planner_ so that it can be reconstructed in
|
|
// the next run.
|
|
for (auto& n : nodes_) {
|
|
for (const auto i : c10::irange(n.outputs().size())) {
|
|
n.Output(i) = IValue();
|
|
}
|
|
}
|
|
planner_.reset();
|
|
}
|
|
|
|
} // namespace jit
|
|
} // namespace torch
|