Summary: When a Static Runtime graph node has sub-blocks, the memory planner does not treat the sub-blocks' inputs as inputs of that node. As a result, the lifetimes of those inputs are computed incorrectly and the corresponding tensor memory is released earlier than required, which causes errors.

Differential Revision: D69195886

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146855
Approved by: https://github.com/swolchok
This commit is contained in:
parent
15635b14ce
commit
fc5913b6bf
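To make the failure mode concrete, below is the plain graph from the new test added by this PR (shown here for orientation; the before/after behavior described next is inferred from the summary above and from the test's expectations). The managed tensor %b is produced by the top-level aten::mul and read only inside block1 of the prim::If:

    graph(%cond : bool, %a : Tensor):
      %b : Tensor = aten::mul(%a, %a)
      %output : bool = prim::If(%cond)
        block0():
          -> (%a)
        block1():
          %c : Tensor = aten::mul(%b, %a)
          -> (%c)
      return (%output)

Previously, the lifetime pass walked only the direct inputs of each top-level node. Since prim::If's only direct input is %cond, %b's recorded lifetime ended at the node that produced it, and the planner could hand %b's buffer out for reuse while block1 still needed it. With this change the pass also walks the inputs of every node nested inside a node's sub-blocks, so %b stays alive through the prim::If; the new test checks exactly that: the first node no longer frees any managed tensor, and %b becomes available for reuse only after the prim::If.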
@@ -1274,6 +1274,59 @@ TEST(ManagedTensorRanges, OverlappingLifetimesOutputs) {
   EXPECT_TRUE(ranges.lifetimesOverlap(b, output));
 }

+TEST(ManagedTensorRanges, LifetimeIncludeSubBlockInputs) {
+  const std::string src_plain = R"IR(
+    graph(%cond : bool, %a : Tensor):
+      %b : Tensor = aten::mul(%a, %a)
+      %output : bool = prim::If(%cond)
+        block0():
+          -> (%a)
+        block1():
+          %c : Tensor = aten::mul(%b, %a)
+          -> (%c)
+      return (%output)
+  )IR";
+  const std::string src_recursive = R"IR(
+    graph(%cond : bool, %a : Tensor):
+      %b : Tensor = aten::mul(%a, %a)
+      %output : bool = prim::If(%cond)
+        block0():
+          -> (%a)
+        block1():
+          %outputblock1 : bool = prim::If(%cond)
+            block0():
+              -> (%a)
+            block1():
+              %c : Tensor = aten::mul(%b, %a)
+              -> (%c)
+          -> (%outputblock1)
+      return (%output)
+  )IR";
+
+  for (const auto& src : {src_plain, src_recursive}) {
+    auto graph = std::make_shared<Graph>();
+    std::unordered_map<std::string, Value*> vmap;
+    parseIR(src, graph.get(), vmap);
+
+    auto* b = vmap["b"];
+
+    FastSet<const Value*> managed_tensors = {b};
+    AliasDb alias_db(graph);
+    auto ranges = ManagedTensorRanges(*graph->block(), alias_db, managed_tensors);
+
+    std::vector<Node*> nodes(
+        graph->block()->nodes().begin(), graph->block()->nodes().end());
+    ASSERT_EQ(nodes.size(), 2);
+
+    EXPECT_FALSE(ranges.nodeFreesManagedTensors(nodes[0]));
+
+    EXPECT_TRUE(ranges.nodeFreesManagedTensors(nodes[1]));
+    EXPECT_EQ(
+        ranges.availableTensorValuesAfterNode(nodes[1]),
+        std::vector<const Value*>{b});
+  }
+}
+
 namespace {

 // For checking the correctness of assignStorageToManageTensors, the following
@@ -395,6 +395,25 @@ bool isPureFunction(const Node* node) {

 } // namespace

+void ManagedTensorRanges::extendLifetime(Value* input, size_t new_end) {
+  auto* lifetime = getLifetime(input);
+  if (lifetime) {
+    TORCH_DCHECK_LE(lifetime->end, new_end);
+    lifetime->end = new_end;
+  }
+}
+
+void ManagedTensorRanges::extendInputLifetime(Node* node, size_t new_end) {
+  for (auto* input : node->inputs()) {
+    extendLifetime(input, new_end);
+  }
+  for (auto* subblock : node->blocks()) {
+    for (auto* subnode : subblock->nodes()) {
+      extendInputLifetime(subnode, new_end);
+    }
+  }
+}
+
 ManagedTensorRanges::ManagedTensorRanges(
     Block& block,
     const AliasDb& alias_db,
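The two helpers above carry the fix: extendLifetime moves a tracked value's lifetime end forward, and extendInputLifetime applies it to a node's direct inputs and then recurses into every node of every sub-block, so values consumed only inside prim::If or prim::Loop bodies are treated as inputs of the enclosing node. The next hunk replaces the old inputs-only loop in the constructor with a single call to extendInputLifetime(node, i). As a standalone sketch of the same traversal (not code from this PR; the helper name collectTransitiveInputs is ours, and only the public torch::jit IR accessors Node::inputs(), Node::blocks(), and Block::nodes() are assumed):

    #include <torch/csrc/jit/ir/ir.h>

    #include <unordered_set>

    // Collect every Value a node reads, either directly or through any node
    // nested inside its sub-blocks; this mirrors the recursion performed by
    // ManagedTensorRanges::extendInputLifetime.
    static void collectTransitiveInputs(
        torch::jit::Node* node,
        std::unordered_set<torch::jit::Value*>& seen) {
      for (auto* input : node->inputs()) {
        seen.insert(input);
      }
      for (auto* subblock : node->blocks()) {
        for (auto* subnode : subblock->nodes()) {
          collectTransitiveInputs(subnode, seen);
        }
      }
    }

Run on the prim::If from the test graph, this walk picks up %b and %a even though the node's own input list contains only %cond, which is exactly the information the old lifetime loop never saw.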
@@ -404,14 +423,7 @@ ManagedTensorRanges::ManagedTensorRanges(
   const auto num_nodes = static_cast<uint32_t>(nodes.size());
   for (const auto i : c10::irange(num_nodes)) {
     auto* node = nodes[i];
-    for (auto* input : node->inputs()) {
-      auto* lifetime = getLifetime(input);
-      if (!lifetime) {
-        continue;
-      }
-      DCHECK(lifetime->end <= i);
-      lifetime->end = i;
-    }
+    extendInputLifetime(node, i);
     for (auto* output : node->outputs()) {
       if (!alias_db.isMutableType(output)) {
         continue;
@@ -137,6 +137,8 @@ class TORCH_API ManagedTensorRanges {
   // type are mutable)
   std::vector<const Value*> collectValuesWithTrackedLifetimes(
       at::ArrayRef<const Value*> values);
+  void extendLifetime(Value* input, size_t new_end);
+  void extendInputLifetime(Node* node, size_t new_end);

   // Maps Node* to the set of managed tensors that are now available
   // for re-use after this node.