diff --git a/test/cpp/nativert/test_alias_analyzer.cpp b/test/cpp/nativert/test_alias_analyzer.cpp
index afa469f58c8..88357938cdd 100644
--- a/test/cpp/nativert/test_alias_analyzer.cpp
+++ b/test/cpp/nativert/test_alias_analyzer.cpp
@@ -43,8 +43,8 @@ class AliasAnalyzerTests : public testing::Test {
    cfg.enableStaticCPUKernels = true;
 
    auto graph = stringToGraph(model);
-    auto kernels = KernelFactory().initializeNodeKernels(
-        *graph, nullptr, cfg, {}, nullptr);
+    auto kernels =
+        KernelFactory().initializeNodeKernels(*graph, nullptr, cfg, nullptr);
    auto kernelSchemas = Executor::getKernelSchemas(kernels.nodeKernels);
 
    AliasAnalyzer analyzer(*graph, kernelSchemas);
diff --git a/test/cpp/nativert/test_weights.cpp b/test/cpp/nativert/test_weights.cpp
index 43d05d5ad88..566bc046987 100644
--- a/test/cpp/nativert/test_weights.cpp
+++ b/test/cpp/nativert/test_weights.cpp
@@ -25,7 +25,7 @@ return(%o2, %baz)
 };
 TEST_F(WeightsTest, ConstructEmptyStateDict) {
   std::unordered_map<std::string, c10::IValue> stateDict;
-  Weights weights(graph.get(), stateDict, *placement);
+  Weights weights(graph.get(), stateDict);
   // Check that weights are initialized correctly
   EXPECT_TRUE(weights.parameters().empty());
   EXPECT_TRUE(weights.buffers().empty());
@@ -33,7 +33,7 @@ TEST_F(WeightsTest, ConstructEmptyStateDict) {
 }
 TEST_F(WeightsTest, SetAndGetValue) {
   std::unordered_map<std::string, c10::IValue> stateDict;
-  Weights weights(graph.get(), stateDict, *placement);
+  Weights weights(graph.get(), stateDict);
   at::Tensor tensor = at::ones({2, 2});
   weights.setValue("added_weight", tensor);
   EXPECT_TRUE(weights.contains("added_weight"));
diff --git a/torch/nativert/executor/Executor.cpp b/torch/nativert/executor/Executor.cpp
index 3a3f3d33513..932972ae2b5 100644
--- a/torch/nativert/executor/Executor.cpp
+++ b/torch/nativert/executor/Executor.cpp
@@ -20,12 +20,10 @@ Executor::Executor(
     torch::nativert::ExecutorConfig executorConfig,
     std::shared_ptr<Graph> graph,
     const std::shared_ptr<Weights>& weights,
-    Placement placement,
     const std::shared_ptr<caffe2::serialize::PyTorchStreamReader>&
         pytorchStreamReader)
     : executorConfig_(std::move(executorConfig)),
       graph_(std::move(graph)),
-      placement_(std::move(placement)),
       constantFolder_(
           executorConfig_.runConstFolding
               ? std::optional<ConstantFolder>(*graph_)
@@ -46,7 +44,7 @@ void Executor::initialize(
   auto start = std::chrono::high_resolution_clock::now();
 
   auto executionKernels = KernelFactory().initializeNodeKernels(
-      *graph_, weights, executorConfig_, placement_, pytorchStreamReader);
+      *graph_, weights, executorConfig_, pytorchStreamReader);
 
   if (constantFolder_.has_value()) {
     constantFolder_->unlinkConstants(executionKernels.nodeKernels);
diff --git a/torch/nativert/executor/Executor.h b/torch/nativert/executor/Executor.h
index 57356c36d6c..4f40946b4b4 100644
--- a/torch/nativert/executor/Executor.h
+++ b/torch/nativert/executor/Executor.h
@@ -15,7 +15,6 @@
 #include
 #include
 #include
-#include <torch/nativert/executor/Placement.h>
 #include
 #include
 #include
@@ -80,7 +79,6 @@ class Executor {
       torch::nativert::ExecutorConfig executorConfig,
       std::shared_ptr<Graph> graph,
       const std::shared_ptr<Weights>& weights,
-      Placement placement = Placement(),
       const std::shared_ptr<caffe2::serialize::PyTorchStreamReader>&
           pytorchStreamReader = nullptr);
 
@@ -180,8 +178,6 @@ class Executor {
 
   std::unique_ptr<GraphExecutorBase> graphExecutor_;
 
-  const Placement placement_;
-
   // NOTE: delegateExecutors_ is used by nodeKernels_ inside graphExecutor_.
   std::vector<std::unique_ptr<DelegateExecutor>> delegateExecutors_;
diff --git a/torch/nativert/executor/Weights.cpp b/torch/nativert/executor/Weights.cpp
index 1c14b79e6d9..918b532160c 100644
--- a/torch/nativert/executor/Weights.cpp
+++ b/torch/nativert/executor/Weights.cpp
@@ -25,11 +25,9 @@ WeightVersion Weights::globalVersion_ = 0;
 Weights::Weights(
     const Graph* graph,
     const std::optional<std::unordered_map<std::string, c10::IValue>>&
-        stateDict,
-    Placement placement)
+        stateDict)
     : graph_(graph),
       weightsMeta_(graph->weightsMeta()),
-      placement_(std::move(placement)),
       version_(globalVersion_++) {
   if (stateDict.has_value()) {
     loadStateDict(stateDict.value());
@@ -43,12 +41,10 @@ Weights::Weights(
     std::string_view stateDictPathPrefix,
     const std::unordered_map<std::string, std::string>& constantPaths,
     std::string_view constantPathPrefix,
-    Placement placement,
     std::function<bool(const std::string&)> skipSizeCheck,
     std::function<bool(const std::string&)> skipDtypeCheck)
     : graph_(graph),
       weightsMeta_(graph->weightsMeta()),
-      placement_(std::move(placement)),
       version_(globalVersion_++),
       skipSizeCheck_(std::move(skipSizeCheck)),
       skipDtypeCheck_(std::move(skipDtypeCheck)) {
@@ -97,7 +93,7 @@ Weights::Weights(
     if (!isUsed) {
       VLOG(1) << "Tensor " << tensorName
               << " is not used during inference";
-      auto targetDevice = placement_.getMappedDevice(tensorMeta->device());
+      auto targetDevice = tensorMeta->device();
       allValues_[tensorName] =
           at::scalar_tensor(0, at::TensorOptions().device(targetDevice));
       return;
@@ -120,7 +116,7 @@ Weights::Weights(
         at::empty({0}, tensorOptions)
             .set_(storage, 0, tensorMeta->sizes(), tensorMeta->strides());
 
-    auto targetDevice = placement_.getMappedDevice(tensorMeta->device());
+    auto targetDevice = tensorMeta->device();
     VLOG(1) << "Loading weight " << tensorName << " on " << targetDevice;
     if (!isSameDevice(targetDevice, tensor.device())) {
       tensor = tensor.to(targetDevice);
@@ -308,7 +304,7 @@ void Weights::loadStateDict(
     TORCH_CHECK(
         it != weightsMeta_.end(), "Couldn't find ", name, " in weightsMeta");
 
-    auto targetDevice = placement_.getMappedDevice(it->second.device());
+    auto targetDevice = it->second.device();
     auto tensor = stateDictIt->second.toTensor().to(targetDevice);
 
     TORCH_CHECK(tensor.sizes() == it->second.sizes());
@@ -351,7 +347,7 @@ void Weights::validateValue(const std::string& name, const at::Tensor& newValue)
       " vs ",
       newValue.dtype());
 
-  auto targetDevice = placement_.getMappedDevice(weightMeta.device());
+  auto targetDevice = weightMeta.device();
   if (targetDevice.is_cpu() && targetDevice.has_index()) {
     LOG(WARNING) << "Target device is cpu but has index: " << targetDevice;
   }
diff --git a/torch/nativert/executor/Weights.h b/torch/nativert/executor/Weights.h
index 5a6778b524f..7108f32bba9 100644
--- a/torch/nativert/executor/Weights.h
+++ b/torch/nativert/executor/Weights.h
@@ -3,7 +3,6 @@
 #include
 #include
 #include
-#include <torch/nativert/executor/Placement.h>
 #include
 
@@ -24,8 +23,7 @@ class Weights {
  explicit Weights(
      const Graph* graph,
      const std::optional<std::unordered_map<std::string, c10::IValue>>&
-         stateDict = std::nullopt,
-     Placement placement = Placement());
+         stateDict = std::nullopt);
 
  // Arguments
  // - pytorchStreamReader: the reader for the model archive
@@ -36,8 +34,6 @@ class Weights {
  // - constantPaths: a map from constant name to file path in the archive
  // - constantPathPrefix: a prefix that will be prepended to paths in
  //   constantPathPrefix
-  // - placement: the device placement of the weights, default to follow the
-  //   original device in the weight's metadata
  explicit Weights(
      const Graph* graph,
      std::shared_ptr<caffe2::serialize::PyTorchStreamReader>
          pytorchStreamReader,
      const std::unordered_map<std::string, std::string>& stateDictPaths,
      std::string_view stateDictPathPrefix,
      const std::unordered_map<std::string, std::string>& constantPaths,
      std::string_view constantPathPrefix,
-      Placement placement = Placement(),
      std::function<bool(const std::string&)> skipSizeCheck = {},
      std::function<bool(const std::string&)> skipDtypeCheck = {});
@@ -107,7 +102,6 @@ class Weights {
  private:
   const Graph* graph_;
   const std::unordered_map<std::string, TensorMeta>& weightsMeta_;
-  Placement placement_;
 
   // keys are parameter/buffer/constant names, not graph input names!
   std::unordered_map<std::string, at::Tensor> allValues_;
diff --git a/torch/nativert/graph/Graph.cpp b/torch/nativert/graph/Graph.cpp
index bce01f278a5..ee5fbaca11b 100644
--- a/torch/nativert/graph/Graph.cpp
+++ b/torch/nativert/graph/Graph.cpp
@@ -661,10 +661,26 @@ void Graph::replaceAllUsesAfterNode(
 }
 
 void Graph::applyDevicePlacement(const Placement& placement) {
-  // TODO: consolidate device info in weight loading here as well.
+  TORCH_CHECK(
+      !placementApplied_,
+      "placement has already been applied to the graph; it must be applied exactly once.");
+
+  placementApplied_ = true;
+
+  // in-place override of each node's device-typed attributes per the placement
   for (auto& node : nodes_) {
     node.applyDevicePlacement(placement);
   }
+
+  // in-place override of the devices in weightsMeta_ per the placement
+  for (auto& [_, weightMeta] : weightsMeta_) {
+    weightMeta.applyDevicePlacement(placement);
+  }
+
+  // in-place override of the devices in tensorValuesMeta_ per the placement
+  for (auto& [_, tensorMeta] : tensorValuesMeta_) {
+    tensorMeta.applyDevicePlacement(placement);
+  }
 }
 
 Node* Graph::nodeAfter(Node* n) {
diff --git a/torch/nativert/graph/Graph.h b/torch/nativert/graph/Graph.h
index 7202272a4aa..a86e9736219 100644
--- a/torch/nativert/graph/Graph.h
+++ b/torch/nativert/graph/Graph.h
@@ -584,6 +584,8 @@ class Graph {
   void setWeightsMeta(
       const std::unordered_map<std::string, torch::_export::TensorMeta>&
           tensorsMeta) {
+    TORCH_CHECK(!placementApplied_);
+
     for (auto [name, tensorMeta] : tensorsMeta) {
       weightsMeta_.emplace(name, TensorMeta{tensorMeta});
     }
@@ -605,6 +607,8 @@ class Graph {
   void setTensorValuesMeta(
       const std::unordered_map<std::string, torch::_export::TensorMeta>&
           tensorsMeta) {
+    TORCH_CHECK(!placementApplied_);
+
     for (auto [name, tensorMeta] : tensorsMeta) {
       tensorValuesMeta_.emplace(name, TensorMeta{tensorMeta});
     }
@@ -630,6 +634,8 @@ class Graph {
   friend std::ostream& operator<<(std::ostream& out, const Graph& g);
 
   GraphSignature signature_;
 
+  bool placementApplied_ = false;
+
   // keys are parameters, buffers, tensor_constants' names
   std::unordered_map<std::string, TensorMeta> weightsMeta_;
diff --git a/torch/nativert/graph/TensorMeta.h b/torch/nativert/graph/TensorMeta.h
index 585383a95b5..7fe9a88c731 100644
--- a/torch/nativert/graph/TensorMeta.h
+++ b/torch/nativert/graph/TensorMeta.h
@@ -10,6 +10,7 @@
 #include
 #include
+#include <torch/nativert/executor/Placement.h>
 
 namespace torch::nativert {
@@ -68,6 +69,11 @@ class TensorMeta {
         requiresGrad_);
   }
 
+  // override the device according to the placement
+  void applyDevicePlacement(const Placement& placement) {
+    device_ = placement.getMappedDevice(device_);
+  }
+
   // NYI
   // c10::SymIntArrayRef sym_sizes() const {}
   // c10::SymIntArrayRef sym_strides() const {}
diff --git a/torch/nativert/kernels/KernelFactory.cpp b/torch/nativert/kernels/KernelFactory.cpp
index 7237d39fec9..88740d2aa45 100644
--- a/torch/nativert/kernels/KernelFactory.cpp
+++ b/torch/nativert/kernels/KernelFactory.cpp
@@ -22,8 +22,7 @@ namespace {
 c10::Device inferTargetDevice(
     const Node& node,
     const std::unordered_map<std::string, TensorMeta>&
-        tensorValuesMeta,
-    const Placement& placement) {
+        tensorValuesMeta) {
   if (node.target() == "prim.Input" || node.target() == "prim.Output") {
     return c10::Device(c10::DeviceType::CPU);
   }
@@ -56,7 +55,7 @@ c10::Device inferTargetDevice(
     }
   }
 
-  return placement.getMappedDevice(devices[0]);
+  return devices[0];
 }
 
 }
@@ -126,7 +125,6 @@ ExecutionKernels KernelFactory::initializeNodeKernels(
     const Graph& graph,
     const std::shared_ptr<Weights>& weights,
    const torch::nativert::ExecutorConfig& executorConfig,
-    const Placement& placement,
     const std::shared_ptr<caffe2::serialize::PyTorchStreamReader>&
         pytorchStreamReader) {
   std::vector<std::unique_ptr<OpKernel>> nodeKernels;
@@ -146,7 +144,7 @@ ExecutionKernels KernelFactory::initializeNodeKernels(
     std::string target = std::string(node.target());
 
     c10::Device targetDevice =
-        inferTargetDevice(node, graph.tensorValuesMeta(), placement);
+        inferTargetDevice(node, graph.tensorValuesMeta());
 
     bool matched = false;
     for (const auto& [_, handler] : handlers) {
@@ -212,8 +210,8 @@ ExecutionKernels KernelFactory::initializeNodeKernels(
     for (const auto& attr : node.attributes()) {
       if (std::holds_alternative<std::unique_ptr<Graph>>(attr.value)) {
         const auto& subgraph = std::get<std::unique_ptr<Graph>>(attr.value);
-        auto executionKernels = initializeNodeKernels(
-            *subgraph, weights, executorConfig, placement);
+        auto executionKernels =
+            initializeNodeKernels(*subgraph, weights, executorConfig);
         TORCH_CHECK(
             executionKernels.delegateExecutors.empty(),
             "HigherOrderKernel does not support delegates");
diff --git a/torch/nativert/kernels/KernelFactory.h b/torch/nativert/kernels/KernelFactory.h
index 3f341f1115d..77308e0055b 100644
--- a/torch/nativert/kernels/KernelFactory.h
+++ b/torch/nativert/kernels/KernelFactory.h
@@ -76,7 +76,6 @@ class KernelFactory {
       const Graph& graph,
       const std::shared_ptr<Weights>& weights,
       const torch::nativert::ExecutorConfig& executorConfig,
-      const Placement& placement,
       const std::shared_ptr<caffe2::serialize::PyTorchStreamReader>&
           pytorchStreamReader = nullptr);
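
Usage sketch (not part of the diff): a minimal illustration of the call order implied by the changes above, assuming the headers touched in the diff are includable under their repository paths and that the state-dict value type is c10::IValue (as suggested by the toTensor() call in Weights::loadStateDict). The placement is applied to the Graph exactly once, and Weights, KernelFactory, and Executor no longer take a Placement argument. The function and variable names below are hypothetical; only Graph::applyDevicePlacement, Placement, and the two-argument Weights constructor come from the diff.

#include <memory>
#include <string>
#include <unordered_map>

#include <torch/nativert/executor/Placement.h>
#include <torch/nativert/executor/Weights.h>
#include <torch/nativert/graph/Graph.h>

using namespace torch::nativert;

// Hypothetical call site showing the new flow.
void buildWeightsWithPlacement(
    const std::shared_ptr<Graph>& graph,
    const std::unordered_map<std::string, c10::IValue>& stateDict,
    const Placement& placement) {
  // Remap devices once, up front: this rewrites the node attributes,
  // weightsMeta_, and tensorValuesMeta_ in place. The guard added in
  // Graph::applyDevicePlacement ensures it runs at most once, and the
  // TORCH_CHECK(!placementApplied_) in the metadata setters means it must
  // happen after setWeightsMeta()/setTensorValuesMeta().
  graph->applyDevicePlacement(placement);

  // Weights (and, downstream, KernelFactory and Executor) now read target
  // devices directly from the already-remapped TensorMeta, so no Placement
  // is passed here.
  Weights weights(graph.get(), stateDict);
}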