[StaticRuntime] Fix a bug that memory planner ignores subblocks (#146728) (#146855)

Summary:

When a Static Runtime graph node has sub-blocks, the memory planner does not treat the sub-blocks' inputs as inputs of the node. As a result, the computed lifetimes of those inputs are too short, and the corresponding tensor memory is released earlier than required, causing errors.
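For illustration, here is a condensed form of the IR from the regression test added below, where %b is a managed tensor whose only use after its definition sits inside a prim::If sub-block. Before this fix, %b's recorded lifetime ended at its defining node, so its buffer could be re-used before the If ran:

    graph(%cond : bool, %a : Tensor):
        %b : Tensor = aten::mul(%a, %a)     # last top-level use of %b
        %output : bool = prim::If(%cond)
          block0():
            -> (%a)
          block1():
            %c : Tensor = aten::mul(%b, %a) # reads %b inside the sub-block
            -> (%c)
        return (%output)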

Differential Revision: D69195886

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146855
Approved by: https://github.com/swolchok
Commit fc5913b6bf (parent 15635b14ce)
Author: Zhou Fang
Date: 2025-02-11 13:59:54 +00:00
Committed by: PyTorch MergeBot
3 changed files with 75 additions and 8 deletions


@@ -1274,6 +1274,59 @@ TEST(ManagedTensorRanges, OverlappingLifetimesOutputs) {
   EXPECT_TRUE(ranges.lifetimesOverlap(b, output));
 }
 
+TEST(ManagedTensorRanges, LifetimeIncludeSubBlockInputs) {
+  const std::string src_plain = R"IR(
+    graph(%cond : bool, %a : Tensor):
+        %b : Tensor = aten::mul(%a, %a)
+        %output : bool = prim::If(%cond)
+          block0():
+            -> (%a)
+          block1():
+            %c : Tensor = aten::mul(%b, %a)
+            -> (%c)
+        return (%output)
+  )IR";
+  const std::string src_recursive = R"IR(
+    graph(%cond : bool, %a : Tensor):
+        %b : Tensor = aten::mul(%a, %a)
+        %output : bool = prim::If(%cond)
+          block0():
+            -> (%a)
+          block1():
+            %outputblock1 : bool = prim::If(%cond)
+              block0():
+                -> (%a)
+              block1():
+                %c : Tensor = aten::mul(%b, %a)
+                -> (%c)
+            -> (%outputblock1)
+        return (%output)
+  )IR";
+  for (const auto& src : {src_plain, src_recursive}) {
+    auto graph = std::make_shared<Graph>();
+    std::unordered_map<std::string, Value*> vmap;
+    parseIR(src, graph.get(), vmap);
+
+    auto* b = vmap["b"];
+    FastSet<const Value*> managed_tensors = {b};
+    AliasDb alias_db(graph);
+    auto ranges =
+        ManagedTensorRanges(*graph->block(), alias_db, managed_tensors);
+
+    std::vector<Node*> nodes(
+        graph->block()->nodes().begin(), graph->block()->nodes().end());
+    ASSERT_EQ(nodes.size(), 2);
+
+    EXPECT_FALSE(ranges.nodeFreesManagedTensors(nodes[0]));
+    EXPECT_TRUE(ranges.nodeFreesManagedTensors(nodes[1]));
+    EXPECT_EQ(
+        ranges.availableTensorValuesAfterNode(nodes[1]),
+        std::vector<const Value*>{b});
+  }
+}
+
 namespace {
 // For checking the correctness of assignStorageToManageTensors, the following
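Note: the test exercises both a single-level prim::If (src_plain) and a nested one (src_recursive). In both cases the managed tensor %b must not be freed by the first top-level node (its definition) and must only become available for re-use after the prim::If, confirming that %b's lifetime now extends over its sub-block use.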


@@ -395,6 +395,25 @@ bool isPureFunction(const Node* node) {
 
 } // namespace
 
+void ManagedTensorRanges::extendLifetime(Value* input, size_t new_end) {
+  auto* lifetime = getLifetime(input);
+  if (lifetime) {
+    TORCH_DCHECK_LE(lifetime->end, new_end);
+    lifetime->end = new_end;
+  }
+}
+
+void ManagedTensorRanges::extendInputLifetime(Node* node, size_t new_end) {
+  for (auto* input : node->inputs()) {
+    extendLifetime(input, new_end);
+  }
+  for (auto* subblock : node->blocks()) {
+    for (auto* subnode : subblock->nodes()) {
+      extendInputLifetime(subnode, new_end);
+    }
+  }
+}
+
 ManagedTensorRanges::ManagedTensorRanges(
     Block& block,
     const AliasDb& alias_db,

@@ -404,14 +423,7 @@ ManagedTensorRanges::ManagedTensorRanges(
   const auto num_nodes = static_cast<uint32_t>(nodes.size());
   for (const auto i : c10::irange(num_nodes)) {
     auto* node = nodes[i];
-    for (auto* input : node->inputs()) {
-      auto* lifetime = getLifetime(input);
-      if (!lifetime) {
-        continue;
-      }
-      DCHECK(lifetime->end <= i);
-      lifetime->end = i;
-    }
+    extendInputLifetime(node, i);
     for (auto* output : node->outputs()) {
       if (!alias_db.isMutableType(output)) {
         continue;
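Note: the constructor previously extended lifetimes only for a node's direct inputs. The new recursive extendInputLifetime() also walks the node's sub-blocks, so a value consumed anywhere inside them stays alive until the enclosing top-level node (index i) has executed.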


@@ -137,6 +137,8 @@ class TORCH_API ManagedTensorRanges {
   // type are mutable)
   std::vector<const Value*> collectValuesWithTrackedLifetimes(
       at::ArrayRef<const Value*> values);
 
+  void extendLifetime(Value* input, size_t new_end);
+  void extendInputLifetime(Node* node, size_t new_end);
   // Maps Node* to the set of managed tensors that are now available
   // for re-use after this node.