mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
More name refactoring of memory planning codes to make it more readable (#54272)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/54272 Test Plan: Imported from OSS Reviewed By: bwasti Differential Revision: D27233881 fbshipit-source-id: f257f16ac0684df055961e539f17d002cb8f1bfe
This commit is contained in:
parent
1ceb90405b
commit
fe2c1268b7
|
|
@ -125,35 +125,38 @@ using LivenessInformation = std::pair<
|
||||||
LivenessInformation GetLivenessInformation(
|
LivenessInformation GetLivenessInformation(
|
||||||
const std::shared_ptr<torch::jit::Graph>& graph,
|
const std::shared_ptr<torch::jit::Graph>& graph,
|
||||||
AliasDb& db) {
|
AliasDb& db) {
|
||||||
|
// map a Value to the set of Values whose live ranges overlap with its own
|
||||||
std::unordered_map<const Value*, std::set<const Value*>> liveness_map;
|
std::unordered_map<const Value*, std::set<const Value*>> liveness_map;
|
||||||
|
// the set of Values whose live ranges extend beyond the current inference
|
||||||
std::unordered_set<const Value*> always_alive;
|
std::unordered_set<const Value*> always_alive;
|
||||||
|
|
||||||
|
// map each Value to its creation order in the graph (Note: only traverse top-level
|
||||||
|
// nodes such that nodes under control-flows are represented by top-level
|
||||||
|
// block nodes)
|
||||||
std::vector<const Value*> values_in_creation_order;
|
std::vector<const Value*> values_in_creation_order;
|
||||||
std::unordered_map<const Value*, size_t> values_in_creation_order_idx;
|
std::unordered_map<const Value*, size_t> values_to_idx_in_creation_order;
|
||||||
for (const auto* node : graph->nodes()) {
|
for (const auto* node : graph->nodes()) {
|
||||||
for (const auto* v : node->outputs()) {
|
for (const auto* v : node->outputs()) {
|
||||||
values_in_creation_order_idx[v] = values_in_creation_order.size();
|
values_to_idx_in_creation_order[v] = values_in_creation_order.size();
|
||||||
values_in_creation_order.emplace_back(v);
|
values_in_creation_order.emplace_back(v);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// maps values to any nodes that consume or produce them
|
// presence of a Value in live_values_use_chain means the Value is alive
|
||||||
//
|
// Value mapped to set of Nodes that may use the Value (i.e., use-chain of
|
||||||
// updated as we traverse the graph. the presence of a key in `live_values`
|
// Value)
|
||||||
// means that the value is currently alive.
|
std::unordered_map<const Value*, std::set<const Node*>> live_values_use_chain;
|
||||||
//
|
// Node mapped to set of Values that the Node may use (i.e., def-chain of node
|
||||||
// invariant: set.size() > 0
|
// inputs)
|
||||||
std::unordered_map<const Value*, std::set<const Node*>> live_values;
|
std::unordered_map<const Node*, std::set<const Value*>> live_nodes_def_chain;
|
||||||
std::unordered_map<const Node*, std::set<const Value*>> live_nodes;
|
|
||||||
|
|
||||||
// inputs and outputs are marked permanently alive
|
// mark inputs, constants, outputs as always_alive
|
||||||
for (const auto* input : graph->inputs()) {
|
for (const auto* input : graph->inputs()) {
|
||||||
always_alive.insert(input);
|
always_alive.insert(input);
|
||||||
}
|
}
|
||||||
for (const auto* output : graph->outputs()) {
|
for (const auto* output : graph->outputs()) {
|
||||||
always_alive.insert(output);
|
always_alive.insert(output);
|
||||||
}
|
}
|
||||||
|
|
||||||
for (const auto* node : graph->nodes()) {
|
for (const auto* node : graph->nodes()) {
|
||||||
if (node->kind() == prim::Constant) {
|
if (node->kind() == prim::Constant) {
|
||||||
for (const auto* output : node->outputs()) {
|
for (const auto* output : node->outputs()) {
|
||||||
|
|
@ -162,13 +165,14 @@ LivenessInformation GetLivenessInformation(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// add v to the current liveness_map
|
||||||
std::function<void(const Value* v)> add_live_value_fn = [&](const Value* v) {
|
std::function<void(const Value* v)> add_live_value_fn = [&](const Value* v) {
|
||||||
if (liveness_map.count(v)) {
|
if (liveness_map.count(v)) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
liveness_map[v] = {};
|
liveness_map[v] = {};
|
||||||
|
|
||||||
for (const auto& live_v : live_values) {
|
for (const auto& live_v : live_values_use_chain) {
|
||||||
liveness_map.at(v).insert(live_v.first);
|
liveness_map.at(v).insert(live_v.first);
|
||||||
liveness_map.at(live_v.first).insert(v);
|
liveness_map.at(live_v.first).insert(v);
|
||||||
}
|
}
|
||||||
|
|
@ -176,45 +180,53 @@ LivenessInformation GetLivenessInformation(
|
||||||
// only add values to the live set if they
|
// only add values to the live set if they
|
||||||
// have deps, otherwise they die immediately
|
// have deps, otherwise they die immediately
|
||||||
if (v->uses().size()) {
|
if (v->uses().size()) {
|
||||||
live_values[v] = {};
|
live_values_use_chain[v] = {};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// record the relationship between v (Value) and its uses (Node)
|
||||||
for (const auto& u : v->uses()) {
|
for (const auto& u : v->uses()) {
|
||||||
const auto* node = u.user;
|
const auto* node = u.user;
|
||||||
// track deps of this value
|
live_values_use_chain.at(v).insert(node);
|
||||||
live_values.at(v).insert(node);
|
live_nodes_def_chain[node].insert(v);
|
||||||
live_nodes[node].insert(v);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// values created after this one that alias it
|
// FIXME(penguin): the following alias refinement seems to assume
|
||||||
std::vector<const Value*> aliased_vs;
|
// that `v` refers to a new tensor created by the node that defines
|
||||||
auto idx = values_in_creation_order_idx[v];
|
// v, thus other Values "before" the node that defines `v` cannot
|
||||||
|
// possibly be aliased to `v`.
|
||||||
|
// TODO(penguin): Is it a limitation of TS alias analysis
|
||||||
|
// so that we need to do such refinement? If so, better improve
|
||||||
|
// alias analysis so that we don't need this special handling here
|
||||||
|
//
|
||||||
|
// Refine aliases of v by including only those created after v
|
||||||
|
std::vector<const Value*> refined_aliases;
|
||||||
|
auto idx = values_to_idx_in_creation_order[v];
|
||||||
for (; idx < values_in_creation_order.size(); ++idx) {
|
for (; idx < values_in_creation_order.size(); ++idx) {
|
||||||
auto* alias_v = values_in_creation_order[idx];
|
auto* alias_v = values_in_creation_order[idx];
|
||||||
if (mayContainAlias(db, v, alias_v)) {
|
if (mayContainAlias(db, v, alias_v)) {
|
||||||
aliased_vs.emplace_back(alias_v);
|
refined_aliases.emplace_back(alias_v);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// for all the values in the alias set,
|
// for all the values in the alias set,
|
||||||
// we set them "alive"
|
// we set them "alive"
|
||||||
for (auto* aliased_v : aliased_vs) {
|
for (auto* aliased_v : refined_aliases) {
|
||||||
add_live_value_fn(aliased_v);
|
add_live_value_fn(aliased_v);
|
||||||
for (const auto& u : aliased_v->uses()) {
|
for (const auto& u : aliased_v->uses()) {
|
||||||
const auto* node = u.user;
|
const auto* node = u.user;
|
||||||
// track deps of the aliased values as if they
|
// track deps of the aliased values as if they
|
||||||
// are our own
|
// are our own
|
||||||
live_values.at(v).insert(node);
|
live_values_use_chain.at(v).insert(node);
|
||||||
live_nodes[node].insert(v);
|
live_nodes_def_chain[node].insert(v);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
auto traverse_node_fn = [&](const Node* node,
|
auto traverse_node_fn = [&](const Node* node,
|
||||||
std::vector<const Value*>& dead) {
|
std::vector<const Value*>& dead) {
|
||||||
if (live_nodes.count(node)) {
|
if (live_nodes_def_chain.count(node)) {
|
||||||
for (const auto* v : live_nodes.at(node)) {
|
for (const auto* v : live_nodes_def_chain.at(node)) {
|
||||||
live_values.at(v).erase(node);
|
live_values_use_chain.at(v).erase(node);
|
||||||
if (!live_values.at(v).size()) {
|
if (!live_values_use_chain.at(v).size()) {
|
||||||
dead.emplace_back(v);
|
dead.emplace_back(v);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -233,11 +245,11 @@ LivenessInformation GetLivenessInformation(
|
||||||
std::vector<const Value*> dead;
|
std::vector<const Value*> dead;
|
||||||
traverse_node_fn(node, dead);
|
traverse_node_fn(node, dead);
|
||||||
for (const auto* dead_value : dead) {
|
for (const auto* dead_value : dead) {
|
||||||
live_values.erase(dead_value);
|
live_values_use_chain.erase(dead_value);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for (const auto& v : live_values) {
|
for (const auto& v : live_values_use_chain) {
|
||||||
TORCH_CHECK(always_alive.count(v.first));
|
TORCH_CHECK(always_alive.count(v.first));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -255,16 +267,16 @@ LivenessInformation GetLivenessInformation(
|
||||||
return std::make_pair(liveness_map, always_alive);
|
return std::make_pair(liveness_map, always_alive);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Implementation specific pruning of values
|
// Collect the set of Values that are candidates for memory planning:
|
||||||
// from "optimizable" set. GetLivenessInformation and FindSameStorageValues
|
// - Values that are used in in-place operators (i.e., _out variants), and
|
||||||
// work with any graph, but we prune out values
|
// - excluding those that are either inputs or outputs of
|
||||||
// that aren't produced by "_out" variants here.
|
// non in-place operators
|
||||||
//
|
//
|
||||||
// Returns
|
// Returns
|
||||||
// first: Values that can be optimized
|
// first: Values that are candidates for memory planning
|
||||||
// second: A deterministic order of all values
|
// second: A deterministic order of all values
|
||||||
std::pair<std::vector<const Value*>, std::vector<const Value*>>
|
std::pair<std::vector<const Value*>, std::vector<const Value*>>
|
||||||
GetOptimizableValues(const std::shared_ptr<torch::jit::Graph>& graph) {
|
GetMemoryPlanningCandidates(const std::shared_ptr<torch::jit::Graph>& graph) {
|
||||||
// for determinism
|
// for determinism
|
||||||
std::unordered_set<const Value*> seen_values;
|
std::unordered_set<const Value*> seen_values;
|
||||||
std::vector<const Value*> all_values;
|
std::vector<const Value*> all_values;
|
||||||
|
|
@ -334,7 +346,7 @@ GetOptimizableValues(const std::shared_ptr<torch::jit::Graph>& graph) {
|
||||||
// NB: This is a deterministic implementation, which makes it easier to tune
|
// NB: This is a deterministic implementation, which makes it easier to tune
|
||||||
// and debug.
|
// and debug.
|
||||||
std::unordered_map<const Value*, std::vector<const Value*>>
|
std::unordered_map<const Value*, std::vector<const Value*>>
|
||||||
FindSameStorageValues(
|
GenerateSameStorageValues(
|
||||||
const LivenessInformation& lm,
|
const LivenessInformation& lm,
|
||||||
const std::pair<std::vector<const Value*>, std::vector<const Value*>>&
|
const std::pair<std::vector<const Value*>, std::vector<const Value*>>&
|
||||||
optimizable,
|
optimizable,
|
||||||
|
|
@ -399,33 +411,44 @@ FindSameStorageValues(
|
||||||
// to preserve determinism
|
// to preserve determinism
|
||||||
std::vector<const Value*> seen;
|
std::vector<const Value*> seen;
|
||||||
|
|
||||||
for (const auto* v : optimizable_values) {
|
auto compute_liveset_fn =
|
||||||
if (always_alive.count(v)) {
|
[&always_alive, &alive_during, &same_storage_values](
|
||||||
continue;
|
std::set<const Value*>& live, const Value* v) {
|
||||||
}
|
|
||||||
// get values that are live during the lifetime of v
|
|
||||||
std::set<const Value*> live;
|
|
||||||
for (const auto* sv : same_storage_values.at(v)) {
|
for (const auto* sv : same_storage_values.at(v)) {
|
||||||
const auto& l = alive_during.count(sv) ? alive_during.at(sv)
|
const auto& l = alive_during.count(sv) ? alive_during.at(sv)
|
||||||
: std::set<const Value*>{};
|
: std::set<const Value*>{};
|
||||||
live.insert(l.begin(), l.end());
|
live.insert(l.begin(), l.end());
|
||||||
}
|
}
|
||||||
live.insert(always_alive.begin(), always_alive.end());
|
live.insert(always_alive.begin(), always_alive.end());
|
||||||
|
};
|
||||||
|
|
||||||
for (const auto* s : seen) {
|
// check if same_storage_values[s] intersects with live
|
||||||
// check if any values in this set of same_storage_values
|
auto intersect_fn = [&same_storage_values](
|
||||||
// are alive at the time of v
|
std::set<const Value*>& live, const Value* s) {
|
||||||
// effectively finding | set_intersection(live, set_of_shared(s)) | > 0
|
bool intersect = false;
|
||||||
bool intersects = false;
|
for (const auto* v : same_storage_values.at(s)) {
|
||||||
for (const auto* candidate_v : same_storage_values.at(s)) {
|
if (live.count(v)) {
|
||||||
if (live.count(candidate_v)) {
|
intersect = true;
|
||||||
intersects = true;
|
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// we can share memory if there's no overlap
|
return intersect;
|
||||||
if (!intersects) {
|
};
|
||||||
|
|
||||||
|
for (const auto* v : optimizable_values) {
|
||||||
|
if (always_alive.count(v)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
// get values that are live during the lifetime of v
|
||||||
|
std::set<const Value*> live;
|
||||||
|
compute_liveset_fn(live, v);
|
||||||
|
for (const auto* s : seen) {
|
||||||
|
// if live(same_storage_values[v]) and same_storage_values[s]
|
||||||
|
// do not overlap, then s and v can share the same storage
|
||||||
|
if (!intersect_fn(live, s)) {
|
||||||
share_storage_fn(v, s);
|
share_storage_fn(v, s);
|
||||||
|
// since s is added to same_storage_values[v], live needs
|
||||||
|
// to be recomputed, so bail out here
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -556,11 +579,12 @@ StaticModule::StaticModule(
|
||||||
auto lm = GetLivenessInformation(graph_, alias_db);
|
auto lm = GetLivenessInformation(graph_, alias_db);
|
||||||
external_values_ = lm.second;
|
external_values_ = lm.second;
|
||||||
if (opts_.optimize_memory) {
|
if (opts_.optimize_memory) {
|
||||||
auto values = GetOptimizableValues(graph_);
|
auto values = GetMemoryPlanningCandidates(graph_);
|
||||||
if (!opts_.enable_out_variant) {
|
if (!opts_.enable_out_variant) {
|
||||||
values.first = {};
|
values.first = {};
|
||||||
}
|
}
|
||||||
value_to_same_storage_values_ = FindSameStorageValues(lm, values, alias_db);
|
value_to_same_storage_values_ =
|
||||||
|
GenerateSameStorageValues(lm, values, alias_db);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -16,8 +16,8 @@ struct TORCH_API StaticModuleOptions {
|
||||||
bool cleanup_activations{true};
|
bool cleanup_activations{true};
|
||||||
bool enable_out_variant{true};
|
bool enable_out_variant{true};
|
||||||
bool optimize_memory{true};
|
bool optimize_memory{true};
|
||||||
bool optimize_output_memory{
|
// to enable MemoryPlanner on output tensors
|
||||||
false}; // to enable MemoryPlanner on output tensors
|
bool optimize_output_memory{false};
|
||||||
};
|
};
|
||||||
|
|
||||||
/// The static runtime supports two execution modes.
|
/// The static runtime supports two execution modes.
|
||||||
|
|
@ -81,7 +81,6 @@ class TORCH_API StaticModule {
|
||||||
typedef enum {
|
typedef enum {
|
||||||
CONSTANT_VALUE = -2, // VALUE nodes defined by prim::Constant
|
CONSTANT_VALUE = -2, // VALUE nodes defined by prim::Constant
|
||||||
INPUT_VALUE = -1, // VALUE nodes representing graph inputs
|
INPUT_VALUE = -1, // VALUE nodes representing graph inputs
|
||||||
OTHER_VALUE = 0 // other VALUE nodes (use non-negative index)
|
|
||||||
} VALUE_KIND;
|
} VALUE_KIND;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user