[SR] Add utility class to determine tensor ranges (#68284)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/68284

Add a new class `ManagedTensorRanges` that determines when managed tensors can be made available for re-use. The class provides a method `availableTensorsAfterNode(Node* node)` that returns a vector of `const Value*` (corresponding to managed tensors) that are not used, either directly or through any alias, after `node`.
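
A minimal usage sketch (illustrative only, not part of this diff; the graph/vmap setup mirrors the new unit tests below):

  FastSet<const Value*> managed_tensors = {vmap["y"], vmap["z"]};
  ManagedTensorRanges ranges(graph, managed_tensors);
  for (auto* node : graph->nodes()) {
    if (ranges.nodeFreesManagedTensors(node)) {
      for (const Value* v : ranges.availableTensorsAfterNode(node)) {
        // v is not used (directly or via any alias) after `node`, so a
        // memory planner could recycle its storage here.
      }
    }
  }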

Test Plan: New unit tests: `buck test caffe2/benchmarks/static_runtime:static_runtime_cpptest`

Reviewed By: swolchok

Differential Revision: D32397207

fbshipit-source-id: fb0d9a23f13abf6f2207e3d7266384966f477fc6
Mike Iovine, 2021-11-19 13:09:24 -08:00 (committed by Facebook GitHub Bot)
parent a6d862c50a
commit ee4cfaa286
3 changed files with 382 additions and 0 deletions


@@ -984,3 +984,178 @@ TEST(ProcessedFunction, ProcessedFunction) {
EXPECT_EQ(transpose_fn.kind(), ProcessedFunction::Kind::kNativeFunction);
EXPECT_FALSE(transpose_fn.checkMemoryOverlap());
}
TEST(ManagedTensorRanges, NoAliases) {
const std::string src = R"IR(
graph(%x : Tensor):
%y : Tensor = aten::mul(%x, %x)
%z : Tensor = aten::mul(%y, %x)
%output : Tensor = aten::mul(%z, %z)
return (%output)
)IR";
auto graph = std::make_shared<Graph>();
std::unordered_map<std::string, Value*> vmap;
parseIR(src, graph.get(), vmap);
auto* y = vmap["y"];
auto* z = vmap["z"];
FastSet<const Value*> managed_tensors = {y, z};
ManagedTensorRanges ranges(graph, managed_tensors);
std::vector<Node*> nodes(
graph->block()->nodes().begin(), graph->block()->nodes().end());
ASSERT_EQ(nodes.size(), 3);
EXPECT_FALSE(ranges.nodeFreesManagedTensors(nodes[0]));
EXPECT_TRUE(ranges.nodeFreesManagedTensors(nodes[1]));
EXPECT_EQ(
ranges.availableTensorsAfterNode(nodes[1]), std::vector<const Value*>{y});
EXPECT_TRUE(ranges.nodeFreesManagedTensors(nodes[2]));
EXPECT_EQ(
ranges.availableTensorsAfterNode(nodes[2]), std::vector<const Value*>{z});
}
TEST(ManagedTensorRanges, AliasExtendingLifetimes) {
const std::string src = R"IR(
graph(%x : Tensor):
%y : Tensor = aten::mul(%x, %x)
%y_size : int[] = aten::size(%y)
%z1 : Tensor = aten::mul(%y, %y)
%y_alias : Tensor = aten::view(%y, %y_size)
%z2 : Tensor = aten::mul(%y_alias, %y_alias)
%output : Tensor = aten::mul(%z1, %z2)
return (%output)
)IR";
auto graph = std::make_shared<Graph>();
std::unordered_map<std::string, Value*> vmap;
parseIR(src, graph.get(), vmap);
auto* y = vmap["y"];
auto* z1 = vmap["z1"];
auto* z2 = vmap["z2"];
FastSet<const Value*> managed_tensors = {y, z1, z2};
ManagedTensorRanges ranges(graph, managed_tensors);
std::vector<Node*> nodes(
graph->block()->nodes().begin(), graph->block()->nodes().end());
ASSERT_EQ(nodes.size(), 6);
for (const auto i : c10::irange(4)) {
EXPECT_FALSE(ranges.nodeFreesManagedTensors(nodes[i]));
}
EXPECT_TRUE(ranges.nodeFreesManagedTensors(nodes[4]));
EXPECT_EQ(
ranges.availableTensorsAfterNode(nodes[4]), std::vector<const Value*>{y});
EXPECT_TRUE(ranges.nodeFreesManagedTensors(nodes[5]));
const auto& available_after_5 = ranges.availableTensorsAfterNode(nodes[5]);
// We don't care about the order, so convert to set. But make sure
// there are no duplicates.
FastSet<const Value*> available_after_5_set(
available_after_5.begin(), available_after_5.end());
EXPECT_EQ(available_after_5_set.size(), available_after_5.size());
EXPECT_EQ(available_after_5_set, FastSet<const Value*>({z1, z2}));
}
TEST(ManagedTensorRanges, UnusedTensor) {
const std::string src = R"IR(
graph(%x : Tensor):
%y : Tensor = aten::mul(%x, %x)
%z : Tensor = aten::mul(%x, %x)
%output : Tensor = aten::mul(%z, %z)
return (%output)
)IR";
auto graph = std::make_shared<Graph>();
std::unordered_map<std::string, Value*> vmap;
parseIR(src, graph.get(), vmap);
auto* y = vmap["y"];
auto* z = vmap["z"];
ManagedTensorRanges ranges(graph, {y});
EXPECT_TRUE(ranges.isUnusedValue(y));
EXPECT_FALSE(ranges.isUnusedValue(z));
}
TEST(ManagedTensorRanges, LifetimeOverlap) {
const std::string src = R"IR(
graph(%a : Tensor):
%b : Tensor = aten::mul(%a, %a)
%c : Tensor = aten::mul(%b, %b)
%c_size : int[] = aten::size(%c)
%c_alias : Tensor = aten::view(%c, %c_size)
%d : Tensor = aten::mul(%a, %a)
%e : Tensor = aten::mul(%c_alias, %c_alias)
%output : Tensor = aten::mul(%e, %e)
return (%output)
)IR";
auto graph = std::make_shared<Graph>();
std::unordered_map<std::string, Value*> vmap;
parseIR(src, graph.get(), vmap);
auto* b = vmap["b"];
auto* c = vmap["c"];
auto* d = vmap["d"];
auto* e = vmap["e"];
ManagedTensorRanges ranges(graph, {b, c, d, e});
const std::vector<std::pair<Value*, Value*>> overlapping_values{
{b, c}, {c, d}, {c, e}};
const std::vector<std::pair<Value*, Value*>> disjoint_values{{b, d}, {b, e}};
for (const auto& values : overlapping_values) {
EXPECT_TRUE(ranges.lifetimesOverlap(values.first, values.second));
EXPECT_TRUE(ranges.lifetimesOverlap(values.second, values.first));
}
for (const auto& values : disjoint_values) {
EXPECT_FALSE(ranges.lifetimesOverlap(values.first, values.second));
EXPECT_FALSE(ranges.lifetimesOverlap(values.second, values.first));
}
}
TEST(ManagedTensorRanges, OverlappingLifetimesContainers) {
const std::string src = R"IR(
graph(%a : Tensor):
%b : Tensor = aten::mul(%a, %a)
%c : Tensor = aten::mul(%b, %b)
%tuple : (Tensor, Tensor) = prim::TupleConstruct(%b, %c)
%b_alias : Tensor, %c_alias : Tensor = prim::TupleUnpack(%tuple)
%d : Tensor = aten::mul(%b_alias, %c_alias)
%output : Tensor = aten::mul(%d, %d)
return (%output)
)IR";
auto graph = std::make_shared<Graph>();
std::unordered_map<std::string, Value*> vmap;
parseIR(src, graph.get(), vmap);
auto* b = vmap["b"];
auto* c = vmap["c"];
auto* d = vmap["d"];
ManagedTensorRanges ranges(graph, {b, c, d});
EXPECT_TRUE(ranges.lifetimesOverlap(b, c));
EXPECT_TRUE(ranges.lifetimesOverlap(b, d));
EXPECT_TRUE(ranges.lifetimesOverlap(c, d));
}
TEST(ManagedTensorRanges, OverlappingLifetimesOutputs) {
const std::string src = R"IR(
graph(%a : Tensor):
%output : Tensor = aten::mul(%a, %a)
%b : Tensor = aten::mul(%a, %a)
return (%output)
)IR";
auto graph = std::make_shared<Graph>();
std::unordered_map<std::string, Value*> vmap;
parseIR(src, graph.get(), vmap);
auto* b = vmap["b"];
auto* output = vmap["output"];
ManagedTensorRanges ranges(graph, {b, output});
EXPECT_TRUE(ranges.lifetimesOverlap(b, output));
}


@@ -624,6 +624,8 @@ void ValueGroup::init(
}
}
namespace {
bool containTensorsOnly(at::ArrayRef<Value*> values) {
// return true only if all outputs are tensors
return std::all_of(values.begin(), values.end(), [](const Value* value) {
@@ -631,6 +633,164 @@ bool containTensorsOnly(at::ArrayRef<Value*> values) {
});
}
bool mayContainAlias(const Value* v1, const Value* v2, const AliasDb& db) {
// AliasDb is not const-correct here, so we have to const_cast
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
return db.mayContainAlias(const_cast<Value*>(v1), const_cast<Value*>(v2));
}
bool isPureFunction(const Node* node) {
auto* schema = node->maybeSchema();
return schema &&
schema->aliasAnalysis() == c10::AliasAnalysisKind::PURE_FUNCTION;
}
} // namespace
ManagedTensorRanges::ManagedTensorRanges(
const std::shared_ptr<Graph>& graph,
const FastSet<const Value*>& managed_tensor_values) {
AliasDb alias_db(graph);
const std::vector<Node*> nodes(graph->nodes().begin(), graph->nodes().end());
const FastSet<const Value*> graph_inputs(
graph->inputs().begin(), graph->inputs().end());
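// Lifetimes are tracked only for mutable-typed, non-input values:
// graph inputs are owned by the caller, and an immutable-typed value
// can never alias a managed tensor.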
auto isUntrackedValue = [&alias_db, &graph_inputs](const Value* value) {
return !alias_db.isMutableType(value) ||
graph_inputs.find(value) != graph_inputs.end();
};
const auto num_nodes = nodes.size();
for (const auto i : c10::irange(num_nodes)) {
auto* node = nodes[i];
for (auto* input : node->inputs()) {
auto* lifetime = getLifetime(input);
if (!lifetime) {
DCHECK(isUntrackedValue(input));
continue;
}
DCHECK(lifetime->end <= i);
lifetime->end = i;
}
for (auto* output : node->outputs()) {
if (!alias_db.isMutableType(output)) {
continue;
}
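// A fresh output's lifetime starts and ends at its producing node;
// uses by later nodes extend `end` via the input loop above.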
value_lifetimes_.emplace(output, Lifetime(i, i));
}
}
for (auto* graph_output : graph->outputs()) {
auto* lifetime = getLifetime(graph_output);
if (!lifetime) {
DCHECK(isUntrackedValue(graph_output));
continue;
}
lifetime->end = num_nodes;
}
// Handle aliases. Aliases may extend a Value*'s lifetime. If a node
// has an input and output that may alias each other, set the input's
// lifetime end to max(input.lifetime_end, output.lifetime_end). Iterate
// backwards to handle chains of aliases.
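// For example, given
//   %b : Tensor = aten::view(%a, %sizes)
//   %c : Tensor = aten::view(%b, %sizes)
// the reverse pass first extends %b's lifetime to cover %c's, then
// extends %a's lifetime to cover %b's already-extended one.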
for (const auto* node : graph->nodes().reverse()) {
if (isPureFunction(node)) {
// If the node is a pure function, it doesn't create any aliases,
// so we can safely skip it.
continue;
}
auto inputs = collectValuesWithTrackedLifetimes(node->inputs());
auto outputs = collectValuesWithTrackedLifetimes(node->outputs());
for (auto* input : inputs) {
auto* input_lifetime = getLifetime(input);
DCHECK(input_lifetime != nullptr);
for (auto* output : outputs) {
if (mayContainAlias(input, output, alias_db)) {
auto* output_lifetime = getLifetime(output);
DCHECK(output_lifetime != nullptr);
input_lifetime->end =
std::max(output_lifetime->end, input_lifetime->end);
}
}
}
}
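// Map each managed tensor to the node after which it is no longer
// needed. A lifetime ending at num_nodes means the value is a graph
// output, so it is not freed until the return node.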
for (auto* managed_tensor : managed_tensor_values) {
auto* lifetime = getLifetime(managed_tensor);
DCHECK(lifetime && lifetime->end <= num_nodes);
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
Node* freeing_node;
if (lifetime->end == num_nodes) {
freeing_node = graph->return_node();
} else {
freeing_node = nodes[lifetime->end];
}
node_to_newly_free_tensors_[freeing_node].emplace_back(managed_tensor);
}
}
bool ManagedTensorRanges::isUnusedValue(const Value* value) const {
auto* lifetime = getLifetime(value);
return lifetime && lifetime->start == lifetime->end;
}
bool ManagedTensorRanges::nodeFreesManagedTensors(Node* node) const {
auto it = node_to_newly_free_tensors_.find(node);
return it != node_to_newly_free_tensors_.end() && !it->second.empty();
}
const std::vector<const Value*>& ManagedTensorRanges::availableTensorsAfterNode(
Node* node) const {
return node_to_newly_free_tensors_.at(node);
}
bool ManagedTensorRanges::lifetimesOverlap(const Value* v1, const Value* v2)
const {
const auto* v1_lifetime = getLifetime(v1);
const auto* v2_lifetime = getLifetime(v2);
if (!v1_lifetime || !v2_lifetime) {
return false;
}
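// Two closed intervals overlap iff the one that starts first is
// still live when the other starts.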
if (v1_lifetime->start < v2_lifetime->start) {
return v1_lifetime->end >= v2_lifetime->start;
}
return v2_lifetime->end >= v1_lifetime->start;
}
const ManagedTensorRanges::Lifetime* ManagedTensorRanges::getLifetime(
const Value* value) const {
auto it = value_lifetimes_.find(value);
if (it != value_lifetimes_.end()) {
return &it->second;
}
return nullptr;
}
ManagedTensorRanges::Lifetime* ManagedTensorRanges::getLifetime(
const Value* value) {
// const_cast is safe here, this is just a way to avoid code duplication
// between the const/non-const versions of getLifetime.
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
const auto* const_this = const_cast<const ManagedTensorRanges*>(this);
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
return const_cast<ManagedTensorRanges::Lifetime*>(
const_this->getLifetime(value));
}
std::vector<const Value*> ManagedTensorRanges::
collectValuesWithTrackedLifetimes(at::ArrayRef<const Value*> values) {
std::vector<const Value*> mutable_values;
mutable_values.reserve(values.size());
std::copy_if(
values.begin(),
values.end(),
std::back_inserter(mutable_values),
[this](const Value* value) { return getLifetime(value) != nullptr; });
return mutable_values;
}
StaticModule::StaticModule(
std::shared_ptr<torch::jit::Graph> g,
const StaticModuleOptions& opts)


@@ -95,6 +95,53 @@ class ValueGroup {
FastSet<const Value*> external_aliases_;
};
class TORCH_API ManagedTensorRanges {
public:
ManagedTensorRanges() = default;
ManagedTensorRanges(
const std::shared_ptr<Graph>& graph,
const FastSet<const Value*>& managed_tensor_values);
// Returns true if this node is the last use of at least one
// managed tensor; availableTensorsAfterNode(node) will then return a
// vector of the managed tensors that are available for re-use
// in the nodes following this one.
bool nodeFreesManagedTensors(Node* node) const;
const std::vector<const Value*>& availableTensorsAfterNode(Node* node) const;
// True if the value has a tracked lifetime and lifetime.start ==
// lifetime.end. "Unused" does not imply "unmanaged" -
// managed tensors can be unused if they're not passed to any ops!
bool isUnusedValue(const Value* value) const;
// For testing. True if v1 and v2 are both mutable types and have lifetimes
// that overlap.
bool lifetimesOverlap(const Value* v1, const Value* v2) const;
private:
struct Lifetime {
Lifetime(size_t start_, size_t end_) : start(start_), end(end_) {}
size_t start;
size_t end;
};
// Returns nullptr if we are not tracking the lifetime of value
Lifetime* getLifetime(const Value* value);
const Lifetime* getLifetime(const Value* value) const;
// Collect all values in the input that have tracked lifetimes.
// A value's lifetime may not be tracked if it is a graph input
// or immutable type (containers with at least one mutable
// type are mutable)
std::vector<const Value*> collectValuesWithTrackedLifetimes(
at::ArrayRef<const Value*> values);
// Maps Node* to the set of managed tensors that are now available
// for re-use after this node.
FastMap<Node*, std::vector<const Value*>> node_to_newly_free_tensors_{};
// Maps each Value* to its lifetime (start node index, end node index)
FastMap<const Value*, Lifetime> value_lifetimes_{};
};
struct TORCH_API StaticModuleOptions {
// to batch allocate (deallocate) tensor storage for all non-escaping
// temporary tensors