mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
More name refactoring of memory planning codes to make it more readable (#54272)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/54272 Test Plan: Imported from OSS Reviewed By: bwasti Differential Revision: D27233881 fbshipit-source-id: f257f16ac0684df055961e539f17d002cb8f1bfe
This commit is contained in:
parent
1ceb90405b
commit
fe2c1268b7
|
|
@ -125,35 +125,38 @@ using LivenessInformation = std::pair<
|
|||
LivenessInformation GetLivenessInformation(
|
||||
const std::shared_ptr<torch::jit::Graph>& graph,
|
||||
AliasDb& db) {
|
||||
// map a Value to a set of Values that overlap live-ranges with the Value's
|
||||
std::unordered_map<const Value*, std::set<const Value*>> liveness_map;
|
||||
// a set of Values whose live-range exceed current inference
|
||||
std::unordered_set<const Value*> always_alive;
|
||||
|
||||
// map Values to its creation order in graph (Note: only traverse top-level
|
||||
// nodes such that nodes under control-flows are represented by top-level
|
||||
// block nodes)
|
||||
std::vector<const Value*> values_in_creation_order;
|
||||
std::unordered_map<const Value*, size_t> values_in_creation_order_idx;
|
||||
std::unordered_map<const Value*, size_t> values_to_idx_in_creation_order;
|
||||
for (const auto* node : graph->nodes()) {
|
||||
for (const auto* v : node->outputs()) {
|
||||
values_in_creation_order_idx[v] = values_in_creation_order.size();
|
||||
values_to_idx_in_creation_order[v] = values_in_creation_order.size();
|
||||
values_in_creation_order.emplace_back(v);
|
||||
}
|
||||
}
|
||||
|
||||
// maps values to any nodes that consume or produce them
|
||||
//
|
||||
// updated as we traverse the graph. the presence of a key in `live_values`
|
||||
// means that the value is currently alive.
|
||||
//
|
||||
// invariant: set.size() > 0
|
||||
std::unordered_map<const Value*, std::set<const Node*>> live_values;
|
||||
std::unordered_map<const Node*, std::set<const Value*>> live_nodes;
|
||||
// presence of a Value in live_values_use_chain means the Value alive
|
||||
// Value mapped to set of Nodes that may use the Value (i.e., use-chain of
|
||||
// Value)
|
||||
std::unordered_map<const Value*, std::set<const Node*>> live_values_use_chain;
|
||||
// Node mapped to set of Values that the Node may use (i.e., def-chain of node
|
||||
// inputs)
|
||||
std::unordered_map<const Node*, std::set<const Value*>> live_nodes_def_chain;
|
||||
|
||||
// inputs and outputs are marked permanently alive
|
||||
// mark inputs, constants, outputs as always_alive
|
||||
for (const auto* input : graph->inputs()) {
|
||||
always_alive.insert(input);
|
||||
}
|
||||
for (const auto* output : graph->outputs()) {
|
||||
always_alive.insert(output);
|
||||
}
|
||||
|
||||
for (const auto* node : graph->nodes()) {
|
||||
if (node->kind() == prim::Constant) {
|
||||
for (const auto* output : node->outputs()) {
|
||||
|
|
@ -162,13 +165,14 @@ LivenessInformation GetLivenessInformation(
|
|||
}
|
||||
}
|
||||
|
||||
// add v to the current liveness_map
|
||||
std::function<void(const Value* v)> add_live_value_fn = [&](const Value* v) {
|
||||
if (liveness_map.count(v)) {
|
||||
return;
|
||||
}
|
||||
liveness_map[v] = {};
|
||||
|
||||
for (const auto& live_v : live_values) {
|
||||
for (const auto& live_v : live_values_use_chain) {
|
||||
liveness_map.at(v).insert(live_v.first);
|
||||
liveness_map.at(live_v.first).insert(v);
|
||||
}
|
||||
|
|
@ -176,45 +180,53 @@ LivenessInformation GetLivenessInformation(
|
|||
// only add values to the live set if they
|
||||
// have deps, otherwise they die immediately
|
||||
if (v->uses().size()) {
|
||||
live_values[v] = {};
|
||||
live_values_use_chain[v] = {};
|
||||
}
|
||||
|
||||
// record the relationship between v (Value) and its uses (Node)
|
||||
for (const auto& u : v->uses()) {
|
||||
const auto* node = u.user;
|
||||
// track deps of this value
|
||||
live_values.at(v).insert(node);
|
||||
live_nodes[node].insert(v);
|
||||
live_values_use_chain.at(v).insert(node);
|
||||
live_nodes_def_chain[node].insert(v);
|
||||
}
|
||||
|
||||
// values created after this one that alias it
|
||||
std::vector<const Value*> aliased_vs;
|
||||
auto idx = values_in_creation_order_idx[v];
|
||||
// FIXME(penguin): the following alias refinement seems to assume
|
||||
// that `v` refers to a new tensor created by the node that defines
|
||||
// v, thus other Values "before" the node that defines `v` cannot
|
||||
// possibly be aliased to `v`.
|
||||
// TODO(penguin): Is it a limitation of TS alias analysis
|
||||
// so that we need to do such refinement? If so, better improve
|
||||
// alias analysis so that we dont need this special handling here
|
||||
//
|
||||
// Refine aliases of v by include only those created after v
|
||||
std::vector<const Value*> refined_aliases;
|
||||
auto idx = values_to_idx_in_creation_order[v];
|
||||
for (; idx < values_in_creation_order.size(); ++idx) {
|
||||
auto* alias_v = values_in_creation_order[idx];
|
||||
if (mayContainAlias(db, v, alias_v)) {
|
||||
aliased_vs.emplace_back(alias_v);
|
||||
refined_aliases.emplace_back(alias_v);
|
||||
}
|
||||
}
|
||||
// for all the values in the alias set,
|
||||
// we set them "alive"
|
||||
for (auto* aliased_v : aliased_vs) {
|
||||
for (auto* aliased_v : refined_aliases) {
|
||||
add_live_value_fn(aliased_v);
|
||||
for (const auto& u : aliased_v->uses()) {
|
||||
const auto* node = u.user;
|
||||
// track deps of the aliased values is if they
|
||||
// are our own
|
||||
live_values.at(v).insert(node);
|
||||
live_nodes[node].insert(v);
|
||||
live_values_use_chain.at(v).insert(node);
|
||||
live_nodes_def_chain[node].insert(v);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
auto traverse_node_fn = [&](const Node* node,
|
||||
std::vector<const Value*>& dead) {
|
||||
if (live_nodes.count(node)) {
|
||||
for (const auto* v : live_nodes.at(node)) {
|
||||
live_values.at(v).erase(node);
|
||||
if (!live_values.at(v).size()) {
|
||||
if (live_nodes_def_chain.count(node)) {
|
||||
for (const auto* v : live_nodes_def_chain.at(node)) {
|
||||
live_values_use_chain.at(v).erase(node);
|
||||
if (!live_values_use_chain.at(v).size()) {
|
||||
dead.emplace_back(v);
|
||||
}
|
||||
}
|
||||
|
|
@ -233,11 +245,11 @@ LivenessInformation GetLivenessInformation(
|
|||
std::vector<const Value*> dead;
|
||||
traverse_node_fn(node, dead);
|
||||
for (const auto* dead_value : dead) {
|
||||
live_values.erase(dead_value);
|
||||
live_values_use_chain.erase(dead_value);
|
||||
}
|
||||
}
|
||||
|
||||
for (const auto& v : live_values) {
|
||||
for (const auto& v : live_values_use_chain) {
|
||||
TORCH_CHECK(always_alive.count(v.first));
|
||||
}
|
||||
|
||||
|
|
@ -255,16 +267,16 @@ LivenessInformation GetLivenessInformation(
|
|||
return std::make_pair(liveness_map, always_alive);
|
||||
}
|
||||
|
||||
// Implementation specific pruning of values
|
||||
// from "optimzable" set. GetLivenessInformation and FindSameStorageValues
|
||||
// work with any graph, but we prune out values
|
||||
// that aren't produced by "_out" variants here.
|
||||
// Collect the set of Values that are candidates for memory planning:
|
||||
// - Values that are used in in-place operators (i.e., _out variants), and
|
||||
// - excluding those that are either inputs or outputs of
|
||||
// non in-place operators
|
||||
//
|
||||
// Returns
|
||||
// first: Values that can be optimized
|
||||
// first: Values that are candidates for memory planning
|
||||
// second: A deterministc order of all values
|
||||
std::pair<std::vector<const Value*>, std::vector<const Value*>>
|
||||
GetOptimizableValues(const std::shared_ptr<torch::jit::Graph>& graph) {
|
||||
GetMemoryPlanningCandidates(const std::shared_ptr<torch::jit::Graph>& graph) {
|
||||
// for determinism
|
||||
std::unordered_set<const Value*> seen_values;
|
||||
std::vector<const Value*> all_values;
|
||||
|
|
@ -334,7 +346,7 @@ GetOptimizableValues(const std::shared_ptr<torch::jit::Graph>& graph) {
|
|||
// NB: This is a deterministic implementation, which makes it easier to tune
|
||||
// and debug.
|
||||
std::unordered_map<const Value*, std::vector<const Value*>>
|
||||
FindSameStorageValues(
|
||||
GenerateSameStorageValues(
|
||||
const LivenessInformation& lm,
|
||||
const std::pair<std::vector<const Value*>, std::vector<const Value*>>&
|
||||
optimizable,
|
||||
|
|
@ -399,33 +411,44 @@ FindSameStorageValues(
|
|||
// to preserve determinism
|
||||
std::vector<const Value*> seen;
|
||||
|
||||
for (const auto* v : optimizable_values) {
|
||||
if (always_alive.count(v)) {
|
||||
continue;
|
||||
}
|
||||
// get values that are live during the lifetime of v
|
||||
std::set<const Value*> live;
|
||||
auto compute_liveset_fn =
|
||||
[&always_alive, &alive_during, &same_storage_values](
|
||||
std::set<const Value*>& live, const Value* v) {
|
||||
for (const auto* sv : same_storage_values.at(v)) {
|
||||
const auto& l = alive_during.count(sv) ? alive_during.at(sv)
|
||||
: std::set<const Value*>{};
|
||||
live.insert(l.begin(), l.end());
|
||||
}
|
||||
live.insert(always_alive.begin(), always_alive.end());
|
||||
};
|
||||
|
||||
for (const auto* s : seen) {
|
||||
// check if any values in this set of same_storage_values
|
||||
// are alive at the time of v
|
||||
// effectively finding | set_intersection(live, set_of_shared(s)) | > 0
|
||||
bool intersects = false;
|
||||
for (const auto* candidate_v : same_storage_values.at(s)) {
|
||||
if (live.count(candidate_v)) {
|
||||
intersects = true;
|
||||
// check if same_storage_values[s] intersects with live
|
||||
auto intersect_fn = [&same_storage_values](
|
||||
std::set<const Value*>& live, const Value* s) {
|
||||
bool intersect = false;
|
||||
for (const auto* v : same_storage_values.at(s)) {
|
||||
if (live.count(v)) {
|
||||
intersect = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
// we can share memory if there's no overlap
|
||||
if (!intersects) {
|
||||
return intersect;
|
||||
};
|
||||
|
||||
for (const auto* v : optimizable_values) {
|
||||
if (always_alive.count(v)) {
|
||||
continue;
|
||||
}
|
||||
// get values that are live during the lifetime of v
|
||||
std::set<const Value*> live;
|
||||
compute_liveset_fn(live, v);
|
||||
for (const auto* s : seen) {
|
||||
// if live(same_storage_values[v]) and same_storage_values[s]
|
||||
// do not overlap, then s and v can share the same storage
|
||||
if (!intersect_fn(live, s)) {
|
||||
share_storage_fn(v, s);
|
||||
// since s is added to same_storage_values[v], live needs
|
||||
// to be recomputed, so bail out here
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
|
@ -556,11 +579,12 @@ StaticModule::StaticModule(
|
|||
auto lm = GetLivenessInformation(graph_, alias_db);
|
||||
external_values_ = lm.second;
|
||||
if (opts_.optimize_memory) {
|
||||
auto values = GetOptimizableValues(graph_);
|
||||
auto values = GetMemoryPlanningCandidates(graph_);
|
||||
if (!opts_.enable_out_variant) {
|
||||
values.first = {};
|
||||
}
|
||||
value_to_same_storage_values_ = FindSameStorageValues(lm, values, alias_db);
|
||||
value_to_same_storage_values_ =
|
||||
GenerateSameStorageValues(lm, values, alias_db);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -16,8 +16,8 @@ struct TORCH_API StaticModuleOptions {
|
|||
bool cleanup_activations{true};
|
||||
bool enable_out_variant{true};
|
||||
bool optimize_memory{true};
|
||||
bool optimize_output_memory{
|
||||
false}; // to enable MemoryPlanner on output tensors
|
||||
// to enable MemoryPlanner on output tensors
|
||||
bool optimize_output_memory{false};
|
||||
};
|
||||
|
||||
/// The static runime supports two execution modes.
|
||||
|
|
@ -81,7 +81,6 @@ class TORCH_API StaticModule {
|
|||
typedef enum {
|
||||
CONSTANT_VALUE = -2, // VALUE nodes defined by prim::Constant
|
||||
INPUT_VALUE = -1, // VALUE nodes representing graph inputs
|
||||
OTHER_VALUE = 0 // other VALUE nodes (use non-negative index)
|
||||
} VALUE_KIND;
|
||||
|
||||
private:
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user