Reland D78841818 (#159216)

Summary: Relanding D78841818 with fixes

Test Plan:
Tested all failing tests

buck build --config fbcode.use_link_groups=true --flagfile fbcode//mode/dev-nosan fbcode//sigmoid/core/executor/memory/test:layout_planner_tests

buck test 'fbcode//mode/opt' fbcode//sigmoid/inference/test:test_passes

Rollback Plan:

Reviewed By: hl475

Differential Revision: D79038615

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159216
Approved by: https://github.com/dolpm
This commit is contained in:
Sherlock Huang 2025-07-28 07:39:35 +00:00 committed by PyTorch MergeBot
parent 799303f655
commit 1abff80fae
11 changed files with 45 additions and 36 deletions

View File

@ -43,8 +43,8 @@ class AliasAnalyzerTests : public testing::Test {
cfg.enableStaticCPUKernels = true;
auto graph = stringToGraph(model);
auto kernels = KernelFactory().initializeNodeKernels(
*graph, nullptr, cfg, {}, nullptr);
auto kernels =
KernelFactory().initializeNodeKernels(*graph, nullptr, cfg, nullptr);
auto kernelSchemas = Executor::getKernelSchemas(kernels.nodeKernels);
AliasAnalyzer analyzer(*graph, kernelSchemas);

View File

@ -25,7 +25,7 @@ return(%o2, %baz)
};
TEST_F(WeightsTest, ConstructEmptyStateDict) {
std::unordered_map<std::string, c10::IValue> stateDict;
Weights weights(graph.get(), stateDict, *placement);
Weights weights(graph.get(), stateDict);
// Check that weights are initialized correctly
EXPECT_TRUE(weights.parameters().empty());
EXPECT_TRUE(weights.buffers().empty());
@ -33,7 +33,7 @@ TEST_F(WeightsTest, ConstructEmptyStateDict) {
}
TEST_F(WeightsTest, SetAndGetValue) {
std::unordered_map<std::string, c10::IValue> stateDict;
Weights weights(graph.get(), stateDict, *placement);
Weights weights(graph.get(), stateDict);
at::Tensor tensor = at::ones({2, 2});
weights.setValue("added_weight", tensor);
EXPECT_TRUE(weights.contains("added_weight"));

View File

@ -20,12 +20,10 @@ Executor::Executor(
torch::nativert::ExecutorConfig executorConfig,
std::shared_ptr<Graph> graph,
const std::shared_ptr<Weights>& weights,
Placement placement,
const std::shared_ptr<caffe2::serialize::PyTorchStreamReader>&
pytorchStreamReader)
: executorConfig_(std::move(executorConfig)),
graph_(std::move(graph)),
placement_(std::move(placement)),
constantFolder_(
executorConfig_.runConstFolding
? std::optional<ConstantFolder>(*graph_)
@ -46,7 +44,7 @@ void Executor::initialize(
auto start = std::chrono::high_resolution_clock::now();
auto executionKernels = KernelFactory().initializeNodeKernels(
*graph_, weights, executorConfig_, placement_, pytorchStreamReader);
*graph_, weights, executorConfig_, pytorchStreamReader);
if (constantFolder_.has_value()) {
constantFolder_->unlinkConstants(executionKernels.nodeKernels);

View File

@ -15,7 +15,6 @@
#include <torch/nativert/executor/ExecutionPlanner.h>
#include <torch/nativert/executor/ExecutorConfig.h>
#include <torch/nativert/executor/GraphExecutorBase.h>
#include <torch/nativert/executor/Placement.h>
#include <torch/nativert/executor/memory/FunctionSchema.h>
#include <torch/nativert/executor/memory/LayoutPlanner.h>
#include <torch/nativert/graph/Graph.h>
@ -80,7 +79,6 @@ class Executor {
torch::nativert::ExecutorConfig executorConfig,
std::shared_ptr<Graph> graph,
const std::shared_ptr<Weights>& weights,
Placement placement = Placement(),
const std::shared_ptr<caffe2::serialize::PyTorchStreamReader>&
pytorchStreamReader = nullptr);
@ -180,8 +178,6 @@ class Executor {
std::unique_ptr<GraphExecutorBase> graphExecutor_;
const Placement placement_;
// NOTE: delegateExecutors_ is used by nodeKernels_ inside graphExecutor_.
std::vector<std::unique_ptr<DelegateExecutor>> delegateExecutors_;

View File

@ -25,11 +25,9 @@ WeightVersion Weights::globalVersion_ = 0;
Weights::Weights(
const Graph* graph,
const std::optional<std::unordered_map<std::string, c10::IValue>>&
stateDict,
Placement placement)
stateDict)
: graph_(graph),
weightsMeta_(graph->weightsMeta()),
placement_(std::move(placement)),
version_(globalVersion_++) {
if (stateDict.has_value()) {
loadStateDict(stateDict.value());
@ -43,12 +41,10 @@ Weights::Weights(
std::string_view stateDictPathPrefix,
const std::unordered_map<std::string, std::string>& constantPaths,
std::string_view constantPathPrefix,
Placement placement,
std::function<bool(const std::string&)> skipSizeCheck,
std::function<bool(const std::string&)> skipDtypeCheck)
: graph_(graph),
weightsMeta_(graph->weightsMeta()),
placement_(std::move(placement)),
version_(globalVersion_++),
skipSizeCheck_(std::move(skipSizeCheck)),
skipDtypeCheck_(std::move(skipDtypeCheck)) {
@ -97,7 +93,7 @@ Weights::Weights(
if (!isUsed) {
VLOG(1) << "Tensor " << tensorName << " is not used during inference";
auto targetDevice = placement_.getMappedDevice(tensorMeta->device());
auto targetDevice = tensorMeta->device();
allValues_[tensorName] =
at::scalar_tensor(0, at::TensorOptions().device(targetDevice));
return;
@ -120,7 +116,7 @@ Weights::Weights(
at::empty({0}, tensorOptions)
.set_(storage, 0, tensorMeta->sizes(), tensorMeta->strides());
auto targetDevice = placement_.getMappedDevice(tensorMeta->device());
auto targetDevice = tensorMeta->device();
VLOG(1) << "Loading weight " << tensorName << " on " << targetDevice;
if (!isSameDevice(targetDevice, tensor.device())) {
tensor = tensor.to(targetDevice);
@ -308,7 +304,7 @@ void Weights::loadStateDict(
TORCH_CHECK(
it != weightsMeta_.end(), "Couldn't find ", name, " in weightsMeta");
auto targetDevice = placement_.getMappedDevice(it->second.device());
auto targetDevice = it->second.device();
auto tensor = stateDictIt->second.toTensor().to(targetDevice);
TORCH_CHECK(tensor.sizes() == it->second.sizes());
@ -351,7 +347,7 @@ void Weights::validateValue(const std::string& name, const at::Tensor& newValue)
" vs ",
newValue.dtype());
auto targetDevice = placement_.getMappedDevice(weightMeta.device());
auto targetDevice = weightMeta.device();
if (targetDevice.is_cpu() && targetDevice.has_index()) {
LOG(WARNING) << "Target device is cpu but has index: " << targetDevice;
}

View File

@ -3,7 +3,6 @@
#include <c10/util/FbcodeMaps.h>
#include <c10/util/Logging.h>
#include <caffe2/serialize/inline_container.h>
#include <torch/nativert/executor/Placement.h>
#include <torch/nativert/graph/Graph.h>
@ -24,8 +23,7 @@ class Weights {
explicit Weights(
const Graph* graph,
const std::optional<std::unordered_map<std::string, c10::IValue>>&
stateDict = std::nullopt,
Placement placement = Placement());
stateDict = std::nullopt);
// Arguments
// - pytorchStreamReader: the reader for the model archive
@ -36,8 +34,6 @@ class Weights {
// - constantPaths: a map from constant name to file path in the archive
// - constantPathPrefix: a prefix that will be prepended to paths in
// constantPathPrefix
// - placement: the device placement of the weights, default to follow the
// original device in the weight's metadata
explicit Weights(
const Graph* graph,
std::shared_ptr<caffe2::serialize::PyTorchStreamReader>
@ -46,7 +42,6 @@ class Weights {
std::string_view stateDictPathPrefix,
const std::unordered_map<std::string, std::string>& constantPaths,
std::string_view constantPathPrefix,
Placement placement = Placement(),
std::function<bool(const std::string&)> skipSizeCheck = {},
std::function<bool(const std::string&)> skipDtypeCheck = {});
@ -107,7 +102,6 @@ class Weights {
private:
const Graph* graph_;
const std::unordered_map<std::string, TensorMeta>& weightsMeta_;
Placement placement_;
// keys are parameter/buffer/constant names, not graph input names!
std::unordered_map<std::string, at::Tensor> allValues_;

View File

@ -661,10 +661,26 @@ void Graph::replaceAllUsesAfterNode(
}
// Rewrites every device reference stored on this graph according to the
// given Placement mapping. This touches three places: device-typed node
// attributes, weight (parameter/buffer/constant) metadata, and tensor
// value metadata. It may be called at most once per graph; a second call
// trips the TORCH_CHECK below.
void Graph::applyDevicePlacement(const Placement& placement) {
  // TODO: consolidate device info in weight loading here as well.
  // Guard against double application: remapping an already-remapped device
  // (e.g. cuda:0 -> cuda:1 applied twice) would silently produce a wrong
  // final device, so enforce exactly-once semantics up front.
  TORCH_CHECK(
      !placementApplied_,
      "placement has been applied to the graph! placement must be applied once and once only.");
  placementApplied_ = true;
  // inplace override node's device-typed attributes according to placement
  for (auto& node : nodes_) {
    node.applyDevicePlacement(placement);
  }
  // inplace override weightMeta_'s device according to placement
  for (auto& [_, weightMeta] : weightsMeta_) {
    weightMeta.applyDevicePlacement(placement);
  }
  // inplace override tensorValuesMeta_'s device according to placement
  for (auto& [_, tensorMeta] : tensorValuesMeta_) {
    tensorMeta.applyDevicePlacement(placement);
  }
}
Node* Graph::nodeAfter(Node* n) {

View File

@ -584,6 +584,8 @@ class Graph {
void setWeightsMeta(
const std::unordered_map<std::string, torch::_export::TensorMeta>&
tensorsMeta) {
TORCH_CHECK(!placementApplied_);
for (auto [name, tensorMeta] : tensorsMeta) {
weightsMeta_.emplace(name, TensorMeta{tensorMeta});
}
@ -605,6 +607,8 @@ class Graph {
void setTensorValuesMeta(
const std::unordered_map<std::string, torch::_export::TensorMeta>&
tensorsMeta) {
TORCH_CHECK(!placementApplied_);
for (auto [name, tensorMeta] : tensorsMeta) {
tensorValuesMeta_.emplace(name, TensorMeta{tensorMeta});
}
@ -630,6 +634,8 @@ class Graph {
friend std::ostream& operator<<(std::ostream& out, const Graph& g);
GraphSignature signature_;
bool placementApplied_ = false;
// keys are parameters, buffers, tensor_constants' names
std::unordered_map<std::string, TensorMeta> weightsMeta_;

View File

@ -10,6 +10,7 @@
#include <c10/util/ArrayRef.h>
#include <torch/csrc/utils/generated_serialization_types.h>
#include <torch/nativert/executor/Placement.h>
namespace torch::nativert {
@ -68,6 +69,11 @@ class TensorMeta {
requiresGrad_);
}
// Override this tensor's device in place according to `placement`:
// device_ is replaced by whatever placement.getMappedDevice maps it to.
// NOTE(review): presumably an identity placement leaves device_
// unchanged — confirm against Placement::getMappedDevice.
void applyDevicePlacement(const Placement& placement) {
  device_ = placement.getMappedDevice(device_);
}
// NYI
// c10::SymIntArrayRef sym_sizes() const {}
// c10::SymIntArrayRef sym_strides() const {}

View File

@ -22,8 +22,7 @@ namespace {
c10::Device inferTargetDevice(
const Node& node,
const std::unordered_map<std::string, torch::nativert::TensorMeta>&
tensorValuesMeta,
const Placement& placement) {
tensorValuesMeta) {
if (node.target() == "prim.Input" || node.target() == "prim.Output") {
return c10::Device(c10::DeviceType::CPU);
}
@ -56,7 +55,7 @@ c10::Device inferTargetDevice(
}
}
return placement.getMappedDevice(devices[0]);
return devices[0];
}
}
@ -126,7 +125,6 @@ ExecutionKernels KernelFactory::initializeNodeKernels(
const Graph& graph,
const std::shared_ptr<Weights>& weights,
const torch::nativert::ExecutorConfig& executorConfig,
const Placement& placement,
const std::shared_ptr<caffe2::serialize::PyTorchStreamReader>&
pytorchStreamReader) {
std::vector<std::unique_ptr<OpKernel>> nodeKernels;
@ -146,7 +144,7 @@ ExecutionKernels KernelFactory::initializeNodeKernels(
std::string target = std::string(node.target());
c10::Device targetDevice =
inferTargetDevice(node, graph.tensorValuesMeta(), placement);
inferTargetDevice(node, graph.tensorValuesMeta());
bool matched = false;
for (const auto& [_, handler] : handlers) {
@ -212,8 +210,8 @@ ExecutionKernels KernelFactory::initializeNodeKernels(
for (const auto& attr : node.attributes()) {
if (std::holds_alternative<std::unique_ptr<Graph>>(attr.value)) {
const auto& subgraph = std::get<std::unique_ptr<Graph>>(attr.value);
auto executionKernels = initializeNodeKernels(
*subgraph, weights, executorConfig, placement);
auto executionKernels =
initializeNodeKernels(*subgraph, weights, executorConfig);
TORCH_CHECK(
executionKernels.delegateExecutors.empty(),
"HigherOrderKernel does not support delegates");

View File

@ -76,7 +76,6 @@ class KernelFactory {
const Graph& graph,
const std::shared_ptr<Weights>& weights,
const torch::nativert::ExecutorConfig& executorConfig,
const Placement& placement,
const std::shared_ptr<caffe2::serialize::PyTorchStreamReader>&
pytorchStreamReader = nullptr);