More name refactoring of the memory planning code to make it more readable (#54272)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/54272

Test Plan: Imported from OSS

Reviewed By: bwasti

Differential Revision: D27233881

fbshipit-source-id: f257f16ac0684df055961e539f17d002cb8f1bfe
Author: Peng Wu, 2021-03-24 19:51:13 -07:00 (committed by Facebook GitHub Bot)
Commit: fe2c1268b7 (parent: 1ceb90405b)
2 changed files with 84 additions and 61 deletions


@@ -125,35 +125,38 @@ using LivenessInformation = std::pair<
LivenessInformation GetLivenessInformation(
const std::shared_ptr<torch::jit::Graph>& graph,
AliasDb& db) {
// map a Value to the set of Values whose live-ranges overlap with it
std::unordered_map<const Value*, std::set<const Value*>> liveness_map;
// a set of Values whose live-ranges extend beyond the current inference run
std::unordered_set<const Value*> always_alive;
// map Values to their creation order in the graph (Note: only top-level
// nodes are traversed, so nodes under control flow are represented by their
// top-level block nodes)
std::vector<const Value*> values_in_creation_order;
std::unordered_map<const Value*, size_t> values_in_creation_order_idx;
std::unordered_map<const Value*, size_t> values_to_idx_in_creation_order;
for (const auto* node : graph->nodes()) {
for (const auto* v : node->outputs()) {
values_in_creation_order_idx[v] = values_in_creation_order.size();
values_to_idx_in_creation_order[v] = values_in_creation_order.size();
values_in_creation_order.emplace_back(v);
}
}
// maps values to any nodes that consume or produce them
//
// updated as we traverse the graph. the presence of a key in `live_values`
// means that the value is currently alive.
//
// invariant: set.size() > 0
std::unordered_map<const Value*, std::set<const Node*>> live_values;
std::unordered_map<const Node*, std::set<const Value*>> live_nodes;
// presence of a Value in live_values_use_chain means the Value is alive
// Value mapped to set of Nodes that may use the Value (i.e., use-chain of
// Value)
std::unordered_map<const Value*, std::set<const Node*>> live_values_use_chain;
// Node mapped to set of Values that the Node may use (i.e., def-chain of node
// inputs)
std::unordered_map<const Node*, std::set<const Value*>> live_nodes_def_chain;
// inputs and outputs are marked permanently alive
// mark inputs, constants, outputs as always_alive
for (const auto* input : graph->inputs()) {
always_alive.insert(input);
}
for (const auto* output : graph->outputs()) {
always_alive.insert(output);
}
for (const auto* node : graph->nodes()) {
if (node->kind() == prim::Constant) {
for (const auto* output : node->outputs()) {
@@ -162,13 +165,14 @@ LivenessInformation GetLivenessInformation(
}
}
// add v to the current liveness_map
std::function<void(const Value* v)> add_live_value_fn = [&](const Value* v) {
if (liveness_map.count(v)) {
return;
}
liveness_map[v] = {};
for (const auto& live_v : live_values) {
for (const auto& live_v : live_values_use_chain) {
liveness_map.at(v).insert(live_v.first);
liveness_map.at(live_v.first).insert(v);
}
@@ -176,45 +180,53 @@ LivenessInformation GetLivenessInformation(
// only add values to the live set if they
// have deps, otherwise they die immediately
if (v->uses().size()) {
live_values[v] = {};
live_values_use_chain[v] = {};
}
// record the relationship between v (Value) and its uses (Node)
for (const auto& u : v->uses()) {
const auto* node = u.user;
// track deps of this value
live_values.at(v).insert(node);
live_nodes[node].insert(v);
live_values_use_chain.at(v).insert(node);
live_nodes_def_chain[node].insert(v);
}
// values created after this one that alias it
std::vector<const Value*> aliased_vs;
auto idx = values_in_creation_order_idx[v];
// FIXME(penguin): the following alias refinement seems to assume
// that `v` refers to a new tensor created by the node that defines
// v, thus other Values "before" the node that defines `v` cannot
// possibly be aliased to `v`.
// TODO(penguin): Is this a limitation of TS alias analysis
// that forces us to do such refinement? If so, better to improve
// alias analysis so that we don't need this special handling here
//
// Refine aliases of v by including only those created after v
std::vector<const Value*> refined_aliases;
auto idx = values_to_idx_in_creation_order[v];
for (; idx < values_in_creation_order.size(); ++idx) {
auto* alias_v = values_in_creation_order[idx];
if (mayContainAlias(db, v, alias_v)) {
aliased_vs.emplace_back(alias_v);
refined_aliases.emplace_back(alias_v);
}
}
// for all the values in the alias set,
// we set them "alive"
for (auto* aliased_v : aliased_vs) {
for (auto* aliased_v : refined_aliases) {
add_live_value_fn(aliased_v);
for (const auto& u : aliased_v->uses()) {
const auto* node = u.user;
// track deps of the aliased values as if they
// were our own
live_values.at(v).insert(node);
live_nodes[node].insert(v);
live_values_use_chain.at(v).insert(node);
live_nodes_def_chain[node].insert(v);
}
}
};
auto traverse_node_fn = [&](const Node* node,
std::vector<const Value*>& dead) {
if (live_nodes.count(node)) {
for (const auto* v : live_nodes.at(node)) {
live_values.at(v).erase(node);
if (!live_values.at(v).size()) {
if (live_nodes_def_chain.count(node)) {
for (const auto* v : live_nodes_def_chain.at(node)) {
live_values_use_chain.at(v).erase(node);
if (!live_values_use_chain.at(v).size()) {
dead.emplace_back(v);
}
}
@@ -233,11 +245,11 @@ LivenessInformation GetLivenessInformation(
std::vector<const Value*> dead;
traverse_node_fn(node, dead);
for (const auto* dead_value : dead) {
live_values.erase(dead_value);
live_values_use_chain.erase(dead_value);
}
}
for (const auto& v : live_values) {
for (const auto& v : live_values_use_chain) {
TORCH_CHECK(always_alive.count(v.first));
}
@@ -255,16 +267,16 @@ LivenessInformation GetLivenessInformation(
return std::make_pair(liveness_map, always_alive);
}
// Implementation-specific pruning of values
// from the "optimizable" set. GetLivenessInformation and FindSameStorageValues
// work with any graph, but we prune out values
// that aren't produced by "_out" variants here.
// Collect the set of Values that are candidates for memory planning:
// - Values that are used in in-place operators (i.e., _out variants), and
// - excluding those that are either inputs or outputs of
// non in-place operators
//
// Returns
// first: Values that can be optimized
// first: Values that are candidates for memory planning
// second: A deterministic order of all values
std::pair<std::vector<const Value*>, std::vector<const Value*>>
GetOptimizableValues(const std::shared_ptr<torch::jit::Graph>& graph) {
GetMemoryPlanningCandidates(const std::shared_ptr<torch::jit::Graph>& graph) {
// for determinism
std::unordered_set<const Value*> seen_values;
std::vector<const Value*> all_values;
@@ -334,7 +346,7 @@ GetOptimizableValues(const std::shared_ptr<torch::jit::Graph>& graph) {
// NB: This is a deterministic implementation, which makes it easier to tune
// and debug.
std::unordered_map<const Value*, std::vector<const Value*>>
FindSameStorageValues(
GenerateSameStorageValues(
const LivenessInformation& lm,
const std::pair<std::vector<const Value*>, std::vector<const Value*>>&
optimizable,
@@ -399,33 +411,44 @@ FindSameStorageValues(
// to preserve determinism
std::vector<const Value*> seen;
for (const auto* v : optimizable_values) {
if (always_alive.count(v)) {
continue;
}
// get values that are live during the lifetime of v
std::set<const Value*> live;
auto compute_liveset_fn =
[&always_alive, &alive_during, &same_storage_values](
std::set<const Value*>& live, const Value* v) {
for (const auto* sv : same_storage_values.at(v)) {
const auto& l = alive_during.count(sv) ? alive_during.at(sv)
: std::set<const Value*>{};
live.insert(l.begin(), l.end());
}
live.insert(always_alive.begin(), always_alive.end());
};
for (const auto* s : seen) {
// check if any values in this set of same_storage_values
// are alive at the time of v
// effectively finding | set_intersection(live, set_of_shared(s)) | > 0
bool intersects = false;
for (const auto* candidate_v : same_storage_values.at(s)) {
if (live.count(candidate_v)) {
intersects = true;
// check if same_storage_values[s] intersects with live
auto intersect_fn = [&same_storage_values](
std::set<const Value*>& live, const Value* s) {
bool intersect = false;
for (const auto* v : same_storage_values.at(s)) {
if (live.count(v)) {
intersect = true;
break;
}
}
// we can share memory if there's no overlap
if (!intersects) {
return intersect;
};
for (const auto* v : optimizable_values) {
if (always_alive.count(v)) {
continue;
}
// get values that are live during the lifetime of v
std::set<const Value*> live;
compute_liveset_fn(live, v);
for (const auto* s : seen) {
// if live(same_storage_values[v]) and same_storage_values[s]
// do not overlap, then s and v can share the same storage
if (!intersect_fn(live, s)) {
share_storage_fn(v, s);
// since s is added to same_storage_values[v], live needs
// to be recomputed, so bail out here
break;
}
}
@@ -556,11 +579,12 @@ StaticModule::StaticModule(
auto lm = GetLivenessInformation(graph_, alias_db);
external_values_ = lm.second;
if (opts_.optimize_memory) {
auto values = GetOptimizableValues(graph_);
auto values = GetMemoryPlanningCandidates(graph_);
if (!opts_.enable_out_variant) {
values.first = {};
}
value_to_same_storage_values_ = FindSameStorageValues(lm, values, alias_db);
value_to_same_storage_values_ =
GenerateSameStorageValues(lm, values, alias_db);
}
}
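
The liveness and storage-sharing passes in the file above can be summarized outside of the JIT graph machinery. Below is a minimal, self-contained sketch of the underlying idea only, not the Static Runtime implementation: values are reduced to integer ids with a [first_def, last_use] interval in a linear node order, an overlap ("liveness") map is built in the spirit of GetLivenessInformation, and values with non-overlapping live ranges are greedily grouped so they could reuse one storage, as GenerateSameStorageValues does for the _out-variant candidates. All names in the sketch (LiveRange, BuildOverlapMap, AssignStorageGroups) are illustrative, and the sketch ignores aliasing (mayContainAlias) and the always_alive set that the real passes honor.

```cpp
#include <cstddef>
#include <map>
#include <set>
#include <utility>
#include <vector>

// [first_def, last_use] position of a value in a linear node order.
using LiveRange = std::pair<size_t, size_t>;

// Map each value id to the set of value ids whose live ranges overlap it,
// analogous to the liveness_map built by GetLivenessInformation.
std::map<int, std::set<int>> BuildOverlapMap(
    const std::map<int, LiveRange>& ranges) {
  std::map<int, std::set<int>> overlaps;
  for (const auto& a : ranges) {
    overlaps[a.first]; // ensure every value has an entry
    for (const auto& b : ranges) {
      if (a.first == b.first) {
        continue;
      }
      // Two ranges overlap iff neither one ends before the other begins.
      if (a.second.first <= b.second.second &&
          b.second.first <= a.second.second) {
        overlaps[a.first].insert(b.first);
      }
    }
  }
  return overlaps;
}

// Greedily place each candidate into the first storage group none of whose
// members overlap it; otherwise open a new group. Values in one group could
// share the same storage. Assumes every candidate has an entry in `overlaps`.
std::vector<std::vector<int>> AssignStorageGroups(
    const std::vector<int>& candidates,
    const std::map<int, std::set<int>>& overlaps) {
  std::vector<std::vector<int>> groups;
  for (int v : candidates) {
    const std::set<int>& live_with_v = overlaps.at(v);
    bool placed = false;
    for (auto& group : groups) {
      bool conflict = false;
      for (int member : group) {
        if (live_with_v.count(member)) {
          conflict = true;
          break;
        }
      }
      if (!conflict) {
        group.push_back(v); // v can reuse this group's storage
        placed = true;
        break;
      }
    }
    if (!placed) {
      groups.push_back({v}); // v needs its own storage
    }
  }
  return groups;
}
```

For example, with ranges a:[0,2], b:[1,3], c:[4,5], the overlap map pairs a with b only, so the greedy pass puts a and c into one group (one shared buffer) while b gets its own.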


@@ -16,8 +16,8 @@ struct TORCH_API StaticModuleOptions {
bool cleanup_activations{true};
bool enable_out_variant{true};
bool optimize_memory{true};
bool optimize_output_memory{
false}; // to enable MemoryPlanner on output tensors
// to enable MemoryPlanner on output tensors
bool optimize_output_memory{false};
};
/// The static runtime supports two execution modes.
@@ -81,7 +81,6 @@ class TORCH_API StaticModule {
typedef enum {
CONSTANT_VALUE = -2, // VALUE nodes defined by prim::Constant
INPUT_VALUE = -1, // VALUE nodes representing graph inputs
OTHER_VALUE = 0 // other VALUE nodes (use non-negative index)
} VALUE_KIND;
private:
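
The StaticModuleOptions struct above controls which of the passes from the first file actually run: in the constructor hunk earlier, optimize_memory gates the candidate/storage-sharing passes, and enable_out_variant clears the candidate set when disabled. The following is a hypothetical configuration sketch; the include path and the way the options are ultimately consumed are assumed rather than taken from this diff.

```cpp
#include <torch/csrc/jit/runtime/static/impl.h> // assumed include path for StaticModuleOptions

void configure_static_runtime_options() {
  torch::jit::StaticModuleOptions opts; // defaults as declared in the header above
  opts.cleanup_activations = true;
  opts.enable_out_variant = true;      // when false, the constructor above clears the
                                       // memory-planning candidate set
  opts.optimize_memory = true;         // gates GetMemoryPlanningCandidates /
                                       // GenerateSameStorageValues in the constructor above
  opts.optimize_output_memory = false; // per the comment above: enables MemoryPlanner
                                       // on output tensors when set to true
  // A StaticModule would then be built from a module/graph plus `opts`; the
  // constructor signature is not part of this diff, so it is omitted here.
}
```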