[nativert] move execution frame to torch (#155830)

Summary: As titled — moves ExecutionFrame (executor state: intermediate values, weights, borrowed inputs) into torch/nativert, with build-file entries and unit tests.

Test Plan:
ci

Rollback Plan:

Differential Revision: D76369008

Pull Request resolved: https://github.com/pytorch/pytorch/pull/155830
Approved by: https://github.com/zhxchen17
This commit is contained in:
dolpm 2025-06-14 03:28:55 +00:00 committed by PyTorch MergeBot
parent a6084b71ed
commit cdfa33a328
5 changed files with 382 additions and 0 deletions

View File

@ -599,6 +599,7 @@ libtorch_nativert_sources = [
"torch/nativert/executor/DelegateExecutor.cpp",
"torch/nativert/executor/Placement.cpp",
"torch/nativert/executor/ExecutionPlanner.cpp",
"torch/nativert/executor/ExecutionFrame.cpp",
"torch/nativert/executor/PlacementUtils.cpp",
"torch/nativert/executor/Weights.cpp",
"torch/nativert/executor/memory/FunctionSchema.cpp",

View File

@ -15,6 +15,7 @@ set(NATIVERT_TEST_SRCS
${TORCH_ROOT}/torch/nativert/executor/memory/FunctionSchema.cpp
${TORCH_ROOT}/torch/nativert/executor/ExecutionPlanner.cpp
${TORCH_ROOT}/torch/nativert/detail/ITree.cpp
${TORCH_ROOT}/torch/nativert/executor/ExecutionFrame.cpp
)
add_executable(test_nativert

View File

@ -0,0 +1,96 @@
#include <gtest/gtest.h>
#include <torch/nativert/executor/ExecutionFrame.h>
namespace torch::nativert {
// Verifies basic frame round-tripping: every graph value can be stored and
// read back as a tensor IValue, and the single user output can be moved out
// with its contents intact.
TEST(ExecutionFrameTest, CreateFrame) {
  auto graph = stringToGraph(R"(
graph(%x, %y):
%a = foo(a=%x, b=%y)
%b = foo1(a=%x, b=%y)
%c = foo2(c=%a, d=%b)
return(%c)
)");
  ExecutionFrame exec_frame(*graph);
  // Seed each slot with a 1-element int tensor that encodes its value id.
  for (auto* value : graph->values()) {
    exec_frame.setIValue(
        value->id(), c10::IValue(at::tensor({value->id()}, at::kInt)));
    EXPECT_EQ(exec_frame.getIValue(value->id()).tagKind(), "Tensor");
  }
  auto moved_outputs = exec_frame.tryMoveUserOutputs();
  EXPECT_EQ(moved_outputs.size(), 1);
  EXPECT_EQ(moved_outputs[0].tagKind(), "Tensor");
  // The output slot was seeded with its own id, so moving it out must
  // recover the id of %c.
  EXPECT_EQ(
      moved_outputs[0].toTensor().item().toInt(), graph->getValue("c")->id());
}
// Verifies that setBorrowedIValue() stores a non-owning borrow: the frame
// never bumps the owner's refcount, and destroying the frame releases the
// borrows without touching the owners.
TEST(ExecutionFrameTest, TestSetBorrowedValue) {
auto graph = stringToGraph(R"(
graph(%x, %y):
%a = foo(a=%x, b=%y)
%b = foo1(a=%x, b=%y)
%c = foo2(c=%a, d=%b)
return(%c)
)");
// x and y are the sole owners of their tensors (use_count == 1).
auto x = c10::IValue(at::tensor({1}, at::kInt));
auto y = c10::IValue(at::tensor({2}, at::kInt));
// Inner scope so the frame's destructor runs before the final checks.
{
auto frame = ExecutionFrame(*graph);
frame.setBorrowedIValue(
graph->getValue("x")->id(),
c10::MaybeOwnedTraits<c10::IValue>::createBorrow(x));
frame.setBorrowedIValue(
graph->getValue("y")->id(),
c10::MaybeOwnedTraits<c10::IValue>::createBorrow(y));
// Reading the borrowed slots must not create owning copies.
[[maybe_unused]] auto& w = frame.getIValue(graph->getValue("x")->id());
[[maybe_unused]] auto& z = frame.getIValue(graph->getValue("y")->id());
// Still exactly one owner each: the borrow did not add a reference.
EXPECT_EQ(x.use_count(), 1);
EXPECT_EQ(y.use_count(), 1);
EXPECT_TRUE(c10::MaybeOwnedTraits<c10::IValue>{}.debugBorrowIsValid(
frame.getIValue(graph->getValue("x")->id())));
EXPECT_TRUE(c10::MaybeOwnedTraits<c10::IValue>{}.debugBorrowIsValid(
frame.getIValue(graph->getValue("y")->id())));
}
// Frame destruction (destroyBorrowedIValues) must leave the owners intact.
EXPECT_EQ(x.use_count(), 1);
EXPECT_EQ(y.use_count(), 1);
}
// Verifies that a weight loaded into the frame behaves as a persistent
// value: it is readable via getTensor() but releaseValue() on it aborts.
TEST(ExecutionFrameTest, TestPersistentValue) {
  auto graph = stringToGraph(R"(
graph(%x, %y, %my_weight):
%a = foo(a=%x, b=%y)
%b = foo1(a=%x, b=%y)
%c = foo2(c=%a, d=%b)
return(%c)
)");
  Weights weights(graph.get());
  weights.setValue("my_weight", at::tensor({1}, at::kInt));
  // Register %my_weight as a weight input on a copy of the signature, then
  // install the copy back onto the graph.
  auto sig = graph->signature();
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
  auto& input_weight_pairs =
      const_cast<std::vector<std::pair<std::string, std::string>>&>(
          sig.inputsToWeights());
  input_weight_pairs.emplace_back("my_weight", "my_weight");
  graph->setSignature(sig);
  ExecutionFrame frame(*graph, weights);
  EXPECT_EQ(frame.weightVersion(), 0);
  const auto weight_id = graph->getValue("my_weight")->id();
  EXPECT_NO_THROW(frame.getTensor(weight_id));
  EXPECT_DEATH(frame.releaseValue(weight_id), "Cannot release persistent value");
}
} // namespace torch::nativert

View File

@ -0,0 +1,145 @@
#include <c10/util/Enumerate.h>
#include <c10/util/Logging.h>
#include <torch/nativert/executor/ExecutionFrame.h>
#include <torch/nativert/executor/ExecutionPlanner.h>
namespace torch::nativert {
// Builds a frame sized for `graph`: one IValue slot per graph value, one
// persistence flag per value, and one movability flag per user output.
// Preloads constant SymInts as persistent values, and indexes the outputs
// of run_const_graph nodes so setWeights() can later route folded
// constants into the right slots.
ExecutionFrame::ExecutionFrame(const Graph& graph)
: graph_(graph),
allValues_(graph.numValues()),
persistent_(graph.numValues()),
moveable_output_mask_(graph.userOutputs().size()) {
// load constant SymInts into execution frame
for (const auto& [valueId, constSymintValue] :
graph_.getConstantSymIntValues()) {
setPersistentIValue(valueId, constSymintValue);
}
for (const Node& node : graph_.nodes()) {
if (node.target() == "torch.ops.higher_order.run_const_graph") {
// Map each const-subgraph output name to the id of the corresponding
// output value in this graph; consumed by setWeights() when placing
// weights.getFoldedConsts().
const auto& const_graph =
std::get<std::unique_ptr<Graph>>(node.attributes().at(0).value);
for (size_t i = 0; i < node.outputs().size(); ++i) {
foldedConstIds_[std::string{const_graph->outputs().at(i)->name()}] =
node.outputs()[i]->id();
}
}
}
}
// Builds a frame for `graph` and immediately installs `weights` (weights,
// custom objects, and folded constants) as persistent values.
ExecutionFrame::ExecutionFrame(const Graph& graph, const Weights& weights)
: ExecutionFrame(graph) {
setWeights(weights);
}
// Installs everything `weights` carries into this frame as persistent
// values: named weights (via the signature's input->weight mapping),
// custom objects, graph-folded constants (routed through foldedConstIds_
// built by the constructor), and const-folded values. Records the weight
// version and recomputes the output-movability mask, since newly
// persistent values must not be moved out.
void ExecutionFrame::setWeights(const Weights& weights) {
weightVersion_ = weights.version();
const auto& inputsToWeights = graph_.signature().inputsToWeights();
for (const auto& [inputName, weightName] : inputsToWeights) {
const Value* value = graph_.getValue(inputName);
setPersistentIValue(value->id(), weights.at(weightName));
}
const auto& inputsToCustomObjs = graph_.signature().inputsToCustomObjs();
for (const auto& [inputName, customObjName] : inputsToCustomObjs) {
const Value* value = graph_.getValue(inputName);
setPersistentIValue(value->id(), weights.getCustomObj(customObjName));
}
// Folded constants are keyed by the const-subgraph's output name; translate
// to this graph's value ids via foldedConstIds_.
for (const auto& [value, tensor] : weights.getFoldedConsts()) {
setPersistentIValue(foldedConstIds_.at(value), tensor);
}
for (const auto& [n, iv] : weights.getConstFoldedValues()) {
const Value* v = graph_.getValue(n);
setPersistentIValue(v->id(), iv);
}
updateMovableOutputs();
}
void ExecutionFrame::updateMovableOutputs() {
moveable_output_mask_.assign(moveable_output_mask_.size(), true);
c10::FastSet<ValueId> inputs;
for (const auto* input : graph_.userInputs()) {
if (input) {
inputs.insert(input->id());
}
}
const auto& outputs = graph_.userOutputs();
const size_t num_outputs = outputs.size();
c10::FastSet<ValueId> seen;
for (size_t i = 0; i < num_outputs; i++) {
auto idx = num_outputs - 1 - i;
if (const Value* const* valuePtr = std::get_if<Value*>(&outputs[idx]);
valuePtr && *valuePtr) {
auto id = (*valuePtr)->id();
/*
values are not moveable if:
1. they are persistent
2. they are inputs (since inputs are borrowed)
3. the value will be moved in a later (right-more) output
*/
if (!seen.insert(id).second || persistent_[id] ||
inputs.find(id) != inputs.end()) {
moveable_output_mask_[idx] = false;
}
}
}
}
// Constructor for testing purposes: sizes the frame to `numValues` slots
// without loading constants or weights (input/output id lists unused).
//
// Fix: the original sized only allValues_, leaving persistent_ and
// moveable_output_mask_ empty. Any later setPersistentIValue(),
// releaseValue(), or isOutputMovable() call would then index an empty
// std::vector<bool> — undefined behavior. Size them to match the main
// constructor's invariants.
ExecutionFrame::ExecutionFrame(
    const Graph& graph,
    size_t numValues,
    const std::vector<ValueId>&,
    const std::vector<ValueId>&)
    : graph_(graph) {
  allValues_.resize(numValues);
  persistent_.resize(numValues, false);
  moveable_output_mask_.resize(graph.userOutputs().size(), false);
}
// Stores `ivalue` into the slot for `id`, taking ownership of it.
// Bounds are only checked in debug builds.
void ExecutionFrame::setIValue(ValueId id, c10::IValue ivalue) {
  DCHECK(static_cast<size_t>(id) < allValues_.size());
  auto& slot = allValues_[id];
  slot = std::move(ivalue);
}
// Stores a borrowed IValue into the slot for `id` and remembers the id so
// destroyBorrowedIValues() can release the borrow later.
void ExecutionFrame::setBorrowedIValue(ValueId id, c10::IValue ivalue) {
  DCHECK(static_cast<size_t>(id) < allValues_.size());
  borrowedValueIds_.emplace_back(id);
  auto& slot = allValues_[id];
  slot = std::move(ivalue);
}
// Returns the tensor stored for `id`; throws if the slot does not hold a
// tensor.
at::Tensor ExecutionFrame::getTensor(ValueId id) const {
  const auto& iv = getIValue(id);
  if (C10_UNLIKELY(!iv.isTensor())) {
    throw std::runtime_error("getTensor called on non-tensor value");
  }
  return iv.toTensor();
}
std::vector<c10::IValue> ExecutionFrame::tryMoveUserOutputs() {
std::vector<c10::IValue> ret;
const auto& outputs = graph_.userOutputs();
ret.reserve(outputs.size());
for (const auto& [i, outputValue] : c10::enumerate(outputs)) {
if (const Value* const* valuePtr = std::get_if<Value*>(&outputValue);
valuePtr && *valuePtr) {
ret.push_back(
isOutputMovable(i) ? moveIValue((*valuePtr)->id())
: getIValue((*valuePtr)->id()));
} else if (Constant const* constant = std::get_if<Constant>(&outputValue)) {
ret.push_back(constantToIValue(*constant));
}
}
return ret;
}
} // namespace torch::nativert

View File

@ -0,0 +1,139 @@
#pragma once
#include <unordered_map>
#include <torch/csrc/distributed/c10d/Work.hpp>
#include <torch/nativert/executor/Weights.h>
#include <torch/nativert/graph/Graph.h>
#include <c10/util/Logging.h>
namespace torch::nativert {
/**
 * This class encapsulates the stateful values of an execution: most
 * notably the tensor values passed between nodes (intermediate
 * activations), plus persistent values (weights, constants), borrowed
 * inputs, and c10d work handles.
 */
class ExecutionFrame {
public:
// Constructor for weight-less graph, used for higher order ops, e.g.
// torch.cond
explicit ExecutionFrame(const Graph& graph);
// Builds the frame and installs `weights` as persistent values.
explicit ExecutionFrame(const Graph& graph, const Weights& weights);
// Constructor for testing purpose
explicit ExecutionFrame(
const Graph& graph,
size_t numValues,
const std::vector<ValueId>& graphInputIds,
const std::vector<ValueId>& graphOutputIds);
// Releases any outstanding borrows before the value table is destroyed.
~ExecutionFrame() {
destroyBorrowedIValues();
}
// Returns the user outputs in order, moving each one out of the frame when
// its movability mask allows, copying otherwise.
std::vector<c10::IValue> tryMoveUserOutputs();
// Steals the value for `id`, leaving the slot moved-from.
c10::IValue moveIValue(ValueId id) {
return std::move(allValues_[id]);
}
// Returns a reference to the stored value. With allowNone=false, a None
// slot trips a DCHECK in debug builds (no check in release builds).
const c10::IValue& getIValue(ValueId id, bool allowNone = true) const {
const auto& iValue = allValues_[id];
if (allowNone && iValue.isNone()) {
return iValue;
}
DCHECK(!iValue.isNone());
return iValue;
}
// Mutable overload of getIValue(); same None-handling semantics.
c10::IValue& getIValue(ValueId id, bool allowNone = true) {
auto& iValue = allValues_[id];
if (allowNone && iValue.isNone()) {
return iValue;
}
DCHECK(!iValue.isNone());
return iValue;
}
// Stores an owned value into the slot for `id`.
void setIValue(ValueId id, c10::IValue ivalue);
// Stores a borrowed (non-owning) value; released by destroyBorrowedIValues().
void setBorrowedIValue(ValueId id, c10::IValue ivalue);
// Returns the tensor at `id`; throws std::runtime_error if not a tensor.
at::Tensor getTensor(ValueId id) const;
std::vector<at::Tensor> getTensorVector(ValueId id) const {
return getIValue(id).toTensorVector();
}
int64_t getSymInt(ValueId id) const {
return getIValue(id).toInt();
}
double getSymFloat(ValueId id) const {
return getIValue(id).toDouble();
}
// Stores a value and marks it persistent: it survives releaseValue() and
// is never moved out by tryMoveUserOutputs().
void setPersistentIValue(ValueId id, c10::IValue ivalue) {
setIValue(id, std::move(ivalue));
persistent_[id] = true;
}
// Resets the slot to None; CHECK-fails if the value is persistent.
void releaseValue(ValueId id) {
CHECK(!persistent_[id]) << "Cannot release persistent value";
allValues_[id] = c10::IValue();
}
// Destroys all recorded borrows (owners' refcounts are untouched).
void destroyBorrowedIValues() {
for (const auto& id : borrowedValueIds_) {
c10::MaybeOwnedTraits<c10::IValue>::destroyBorrow(getIValue(id));
}
borrowedValueIds_.clear();
}
// Registers a c10d work handle under `workId`.
void setWork(int64_t workId, const c10::intrusive_ptr<c10d::Work>& work) {
work_[workId] = work;
}
// Looks up a previously registered work handle; CHECK-fails if absent.
c10::intrusive_ptr<c10d::Work> getWork(int64_t workId) const {
CHECK(work_.find(workId) != work_.end())
<< "Couldn't find work with Id: " << workId;
return work_.at(workId);
}
// Version of the weights last installed via setWeights(); -1 if none.
WeightVersion weightVersion() const {
return weightVersion_;
}
// Installs weights, custom objects, and folded constants as persistent
// values, then recomputes the output movability mask.
void setWeights(const Weights& weights);
private:
// Whether user output `idx` may be moved out (TORCH_CHECK-guarded index).
bool isOutputMovable(size_t idx) const {
TORCH_CHECK_LT(idx, moveable_output_mask_.size());
return moveable_output_mask_[idx];
}
// Rebuilds moveable_output_mask_ from persistence, input-borrowing, and
// duplicate-output constraints.
void updateMovableOutputs();
const Graph& graph_;
// -1 until setWeights() records the installed weights' version.
WeightVersion weightVersion_ = -1;
// All the intermediate values for the entire graph, including graph inputs
// and outputs This table is fixed once constructed
std::vector<c10::IValue> allValues_;
// persistent_[id] is true for weights/constants that must not be released.
std::vector<bool> persistent_;
// c10d work handles keyed by caller-chosen id (see setWork/getWork).
std::unordered_map<int64_t, c10::intrusive_ptr<c10d::Work>> work_;
// Ids whose slots hold borrows, to be destroyed on frame teardown.
std::vector<ValueId> borrowedValueIds_;
// Maps const-subgraph output names to this graph's value ids; built in the
// constructor, consumed by setWeights().
std::unordered_map<std::string, ValueId> foldedConstIds_;
// moveable_output_mask_[i] corresponds to user_outputs_[i]
//
// if moveable_output_mask_[i] is true, then user_outputs_[i]
// can be moved
std::vector<bool> moveable_output_mask_;
};
} // namespace torch::nativert