mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[nativert] make runtime const folding aware of run_const_graph (#160760)
Summary: it's possible that we have foldable nodes that use things that will be folded by run_const_graph Test Plan: CI Rollback Plan: Differential Revision: D80355542 Pull Request resolved: https://github.com/pytorch/pytorch/pull/160760 Approved by: https://github.com/SherlockNoMad
This commit is contained in:
parent
9d18bf01b1
commit
2f50ae7d20
|
|
@ -41,6 +41,8 @@ void ConstantFolder::unlinkConstants(
|
|||
const auto* input = &*graph_.nodes().begin();
|
||||
const auto* output = &*graph_.nodes().end();
|
||||
|
||||
c10::FastSet<const Node*> run_const_graph_nodes;
|
||||
|
||||
{ // ignore prim.Input and prim.Output
|
||||
auto ct = 0;
|
||||
for (auto& n : graph_.nodes()) {
|
||||
|
|
@ -49,6 +51,19 @@ void ConstantFolder::unlinkConstants(
|
|||
}
|
||||
nodeDynInputs[&n] = n.numInputs();
|
||||
nodeKernels[&n] = &kernels[++ct];
|
||||
|
||||
if (n.target() == "torch.ops.higher_order.run_const_graph") {
|
||||
run_const_graph_nodes.insert(&n);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (const auto* run_const_graph_node : run_const_graph_nodes) {
|
||||
for (auto* user : run_const_graph_node->users()) {
|
||||
if (user == input || user == output) {
|
||||
continue;
|
||||
}
|
||||
nodeDynInputs[user] -= 1;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -115,13 +115,14 @@ void Executor::maybeRunConstantFolding(
|
|||
weights->updateFoldedConst(value->name(), outputs.at(idx));
|
||||
}
|
||||
}
|
||||
// runtime constant folding after the run_const_graph HOPs, if applicable
|
||||
if (constantFolder_.has_value()) {
|
||||
constantFolder_->evaluate(*weights);
|
||||
}
|
||||
}
|
||||
|
||||
void Executor::processWeights(const std::shared_ptr<Weights>& weights) {
|
||||
maybeRunConstantFolding(weights);
|
||||
if (constantFolder_.has_value()) {
|
||||
constantFolder_->evaluate(*weights);
|
||||
}
|
||||
for (auto& delegateExecutor : delegateExecutors_) {
|
||||
delegateExecutor->processWeights(weights);
|
||||
}
|
||||
|
|
@ -129,9 +130,6 @@ void Executor::processWeights(const std::shared_ptr<Weights>& weights) {
|
|||
|
||||
void Executor::initWeights(const std::shared_ptr<Weights>& weights) {
|
||||
maybeRunConstantFolding(weights);
|
||||
if (constantFolder_.has_value()) {
|
||||
constantFolder_->evaluate(*weights);
|
||||
}
|
||||
|
||||
weights_.withLock([&](auto& w) { w = std::move(weights); });
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user