[nativert] make runtime const folding aware of run_const_graph (#160760)

Summary: it's possible that we have foldable nodes that use things that will be folded by run_const_graph

Test Plan:
CI

Rollback Plan:

Differential Revision: D80355542

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160760
Approved by: https://github.com/SherlockNoMad
This commit is contained in:
Dylan Maloy 2025-08-21 05:22:00 +00:00 committed by PyTorch MergeBot
parent 9d18bf01b1
commit 2f50ae7d20
2 changed files with 19 additions and 6 deletions

View File

@ -41,6 +41,8 @@ void ConstantFolder::unlinkConstants(
const auto* input = &*graph_.nodes().begin();
const auto* output = &*graph_.nodes().end();
c10::FastSet<const Node*> run_const_graph_nodes;
{ // ignore prim.Input and prim.Output
auto ct = 0;
for (auto& n : graph_.nodes()) {
@ -49,6 +51,19 @@ void ConstantFolder::unlinkConstants(
}
nodeDynInputs[&n] = n.numInputs();
nodeKernels[&n] = &kernels[++ct];
if (n.target() == "torch.ops.higher_order.run_const_graph") {
run_const_graph_nodes.insert(&n);
}
}
}
for (const auto* run_const_graph_node : run_const_graph_nodes) {
for (auto* user : run_const_graph_node->users()) {
if (user == input || user == output) {
continue;
}
nodeDynInputs[user] -= 1;
}
}

View File

@ -115,13 +115,14 @@ void Executor::maybeRunConstantFolding(
weights->updateFoldedConst(value->name(), outputs.at(idx));
}
}
// runtime constant folding after the run_const_graph HOPs, if applicable
if (constantFolder_.has_value()) {
constantFolder_->evaluate(*weights);
}
}
void Executor::processWeights(const std::shared_ptr<Weights>& weights) {
maybeRunConstantFolding(weights);
if (constantFolder_.has_value()) {
constantFolder_->evaluate(*weights);
}
for (auto& delegateExecutor : delegateExecutors_) {
delegateExecutor->processWeights(weights);
}
@ -129,9 +130,6 @@ void Executor::processWeights(const std::shared_ptr<Weights>& weights) {
void Executor::initWeights(const std::shared_ptr<Weights>& weights) {
maybeRunConstantFolding(weights);
if (constantFolder_.has_value()) {
constantFolder_->evaluate(*weights);
}
weights_.withLock([&](auto& w) { w = std::move(weights); });