Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-06 12:20:52 +01:00)
Reland D78841818 (#159216)

Summary: Relanding D78841818 with fixes

Test Plan: Tested all failing tests

buck build --config fbcode.use_link_groups=true --flagfile fbcode//mode/dev-nosan fbcode//sigmoid/core/executor/memory/test:layout_planner_tests
buck test 'fbcode//mode/opt' fbcode//sigmoid/inference/test:test_passes

Rollback Plan:

Reviewed By: hl475

Differential Revision: D79038615

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159216
Approved by: https://github.com/dolpm
parent 799303f655
commit 1abff80fae
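Taken together, the diff below drops the Placement argument that used to be threaded through Weights, Executor, and KernelFactory::initializeNodeKernels, and instead applies placement once to the Graph. The C++ fragment below is a minimal sketch of the implied call-site migration; the wrapper function, variable names, and setup are assumptions (not part of this change), and the relevant torch/nativert headers are taken as already included.

#include <memory>
#include <string>
#include <unordered_map>
// torch/nativert headers for Graph, Weights, KernelFactory, Placement,
// and ExecutorConfig are assumed to be included here.

using namespace torch::nativert;

// Sketch only: call-site migration implied by this change. The types and
// member functions are the ones touched by the diff below; everything else
// (names, setup) is assumed.
void buildKernelsSketch(
    const std::string& model,          // serialized graph, as in the tests below
    const Placement& placement,
    const ExecutorConfig& cfg) {
  auto graph = stringToGraph(model);   // helper used by the tests in this diff

  // Before: Placement was passed separately to each consumer, e.g.
  //   Weights weights(graph.get(), stateDict, placement);
  //   KernelFactory().initializeNodeKernels(*graph, weights, cfg, placement, reader);

  // After: apply placement to the graph exactly once, up front...
  graph->applyDevicePlacement(placement);

  // ...then every consumer reads the already-remapped device metadata.
  std::unordered_map<std::string, c10::IValue> stateDict;
  auto weights = std::make_shared<Weights>(graph.get(), stateDict);
  auto kernels = KernelFactory().initializeNodeKernels(
      *graph, weights, cfg, /*pytorchStreamReader=*/nullptr);
  (void)kernels;
}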
@@ -43,8 +43,8 @@ class AliasAnalyzerTests : public testing::Test {
     cfg.enableStaticCPUKernels = true;

     auto graph = stringToGraph(model);
-    auto kernels = KernelFactory().initializeNodeKernels(
-        *graph, nullptr, cfg, {}, nullptr);
+    auto kernels =
+        KernelFactory().initializeNodeKernels(*graph, nullptr, cfg, nullptr);
     auto kernelSchemas = Executor::getKernelSchemas(kernels.nodeKernels);

     AliasAnalyzer analyzer(*graph, kernelSchemas);
@@ -25,7 +25,7 @@ return(%o2, %baz)
 };
 TEST_F(WeightsTest, ConstructEmptyStateDict) {
   std::unordered_map<std::string, c10::IValue> stateDict;
-  Weights weights(graph.get(), stateDict, *placement);
+  Weights weights(graph.get(), stateDict);
   // Check that weights are initialized correctly
   EXPECT_TRUE(weights.parameters().empty());
   EXPECT_TRUE(weights.buffers().empty());
@@ -33,7 +33,7 @@ TEST_F(WeightsTest, ConstructEmptyStateDict) {
 }
 TEST_F(WeightsTest, SetAndGetValue) {
   std::unordered_map<std::string, c10::IValue> stateDict;
-  Weights weights(graph.get(), stateDict, *placement);
+  Weights weights(graph.get(), stateDict);
   at::Tensor tensor = at::ones({2, 2});
   weights.setValue("added_weight", tensor);
   EXPECT_TRUE(weights.contains("added_weight"));
@@ -20,12 +20,10 @@ Executor::Executor(
     torch::nativert::ExecutorConfig executorConfig,
     std::shared_ptr<Graph> graph,
     const std::shared_ptr<Weights>& weights,
-    Placement placement,
     const std::shared_ptr<caffe2::serialize::PyTorchStreamReader>&
         pytorchStreamReader)
     : executorConfig_(std::move(executorConfig)),
       graph_(std::move(graph)),
-      placement_(std::move(placement)),
       constantFolder_(
           executorConfig_.runConstFolding
               ? std::optional<ConstantFolder>(*graph_)
@@ -46,7 +44,7 @@ void Executor::initialize(
   auto start = std::chrono::high_resolution_clock::now();

   auto executionKernels = KernelFactory().initializeNodeKernels(
-      *graph_, weights, executorConfig_, placement_, pytorchStreamReader);
+      *graph_, weights, executorConfig_, pytorchStreamReader);

   if (constantFolder_.has_value()) {
     constantFolder_->unlinkConstants(executionKernels.nodeKernels);
@@ -15,7 +15,6 @@
 #include <torch/nativert/executor/ExecutionPlanner.h>
 #include <torch/nativert/executor/ExecutorConfig.h>
 #include <torch/nativert/executor/GraphExecutorBase.h>
-#include <torch/nativert/executor/Placement.h>
 #include <torch/nativert/executor/memory/FunctionSchema.h>
 #include <torch/nativert/executor/memory/LayoutPlanner.h>
 #include <torch/nativert/graph/Graph.h>
@@ -80,7 +79,6 @@ class Executor {
       torch::nativert::ExecutorConfig executorConfig,
       std::shared_ptr<Graph> graph,
       const std::shared_ptr<Weights>& weights,
-      Placement placement = Placement(),
       const std::shared_ptr<caffe2::serialize::PyTorchStreamReader>&
           pytorchStreamReader = nullptr);

@@ -180,8 +178,6 @@ class Executor {

   std::unique_ptr<GraphExecutorBase> graphExecutor_;

-  const Placement placement_;
-
   // NOTE: delegateExecutors_ is used by nodeKernels_ inside graphExecutor_.
   std::vector<std::unique_ptr<DelegateExecutor>> delegateExecutors_;

@@ -25,11 +25,9 @@ WeightVersion Weights::globalVersion_ = 0;
 Weights::Weights(
     const Graph* graph,
     const std::optional<std::unordered_map<std::string, c10::IValue>>&
-        stateDict,
-    Placement placement)
+        stateDict)
     : graph_(graph),
       weightsMeta_(graph->weightsMeta()),
-      placement_(std::move(placement)),
       version_(globalVersion_++) {
   if (stateDict.has_value()) {
     loadStateDict(stateDict.value());
@@ -43,12 +41,10 @@ Weights::Weights(
     std::string_view stateDictPathPrefix,
     const std::unordered_map<std::string, std::string>& constantPaths,
     std::string_view constantPathPrefix,
-    Placement placement,
     std::function<bool(const std::string&)> skipSizeCheck,
     std::function<bool(const std::string&)> skipDtypeCheck)
     : graph_(graph),
       weightsMeta_(graph->weightsMeta()),
-      placement_(std::move(placement)),
       version_(globalVersion_++),
       skipSizeCheck_(std::move(skipSizeCheck)),
       skipDtypeCheck_(std::move(skipDtypeCheck)) {
@@ -97,7 +93,7 @@ Weights::Weights(

     if (!isUsed) {
       VLOG(1) << "Tensor " << tensorName << " is not used during inference";
-      auto targetDevice = placement_.getMappedDevice(tensorMeta->device());
+      auto targetDevice = tensorMeta->device();
       allValues_[tensorName] =
           at::scalar_tensor(0, at::TensorOptions().device(targetDevice));
       return;
@@ -120,7 +116,7 @@ Weights::Weights(
         at::empty({0}, tensorOptions)
             .set_(storage, 0, tensorMeta->sizes(), tensorMeta->strides());

-    auto targetDevice = placement_.getMappedDevice(tensorMeta->device());
+    auto targetDevice = tensorMeta->device();
     VLOG(1) << "Loading weight " << tensorName << " on " << targetDevice;
     if (!isSameDevice(targetDevice, tensor.device())) {
       tensor = tensor.to(targetDevice);
@@ -308,7 +304,7 @@ void Weights::loadStateDict(
     TORCH_CHECK(
         it != weightsMeta_.end(), "Couldn't find ", name, " in weightsMeta");

-    auto targetDevice = placement_.getMappedDevice(it->second.device());
+    auto targetDevice = it->second.device();
     auto tensor = stateDictIt->second.toTensor().to(targetDevice);

     TORCH_CHECK(tensor.sizes() == it->second.sizes());
@@ -351,7 +347,7 @@ void Weights::validateValue(const std::string& name, const at::Tensor& newValue)
       " vs ",
       newValue.dtype());

-  auto targetDevice = placement_.getMappedDevice(weightMeta.device());
+  auto targetDevice = weightMeta.device();
   if (targetDevice.is_cpu() && targetDevice.has_index()) {
     LOG(WARNING) << "Target device is cpu but has index: " << targetDevice;
   }
@@ -3,7 +3,6 @@
 #include <c10/util/FbcodeMaps.h>
 #include <c10/util/Logging.h>
 #include <caffe2/serialize/inline_container.h>
-#include <torch/nativert/executor/Placement.h>

 #include <torch/nativert/graph/Graph.h>

@@ -24,8 +23,7 @@ class Weights {
   explicit Weights(
       const Graph* graph,
       const std::optional<std::unordered_map<std::string, c10::IValue>>&
-          stateDict = std::nullopt,
-      Placement placement = Placement());
+          stateDict = std::nullopt);

   // Arguments
   // - pytorchStreamReader: the reader for the model archive
@@ -36,8 +34,6 @@ class Weights {
   // - constantPaths: a map from constant name to file path in the archive
   // - constantPathPrefix: a prefix that will be prepended to paths in
   //     constantPathPrefix
-  // - placement: the device placement of the weights, default to follow the
-  //     original device in the weight's metadata
   explicit Weights(
       const Graph* graph,
       std::shared_ptr<caffe2::serialize::PyTorchStreamReader>
@@ -46,7 +42,6 @@ class Weights {
       std::string_view stateDictPathPrefix,
       const std::unordered_map<std::string, std::string>& constantPaths,
       std::string_view constantPathPrefix,
-      Placement placement = Placement(),
       std::function<bool(const std::string&)> skipSizeCheck = {},
       std::function<bool(const std::string&)> skipDtypeCheck = {});

@@ -107,7 +102,6 @@ class Weights {
  private:
   const Graph* graph_;
   const std::unordered_map<std::string, TensorMeta>& weightsMeta_;
-  Placement placement_;

   // keys are parameter/buffer/constant names, not graph input names!
   std::unordered_map<std::string, at::Tensor> allValues_;
@@ -661,10 +661,26 @@ void Graph::replaceAllUsesAfterNode(
 }

+void Graph::applyDevicePlacement(const Placement& placement) {
+  // TODO: consolidate device info in weight loading here as well.
+  TORCH_CHECK(
+      !placementApplied_,
+      "placement has been applied to the graph! placement must be applied once and once only.");
+
+  placementApplied_ = true;
+
+  // inplace override node's device-typed attributes according to placement
+  for (auto& node : nodes_) {
+    node.applyDevicePlacement(placement);
+  }
+
+  // inplace override weightMeta_'s device according to placement
+  for (auto& [_, weightMeta] : weightsMeta_) {
+    weightMeta.applyDevicePlacement(placement);
+  }
+
+  // inplace override tensorValuesMeta_'s device according to placement
+  for (auto& [_, tensorMeta] : tensorValuesMeta_) {
+    tensorMeta.applyDevicePlacement(placement);
+  }
+}
+
 Node* Graph::nodeAfter(Node* n) {
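The new Graph::applyDevicePlacement above rewrites device info in place (node attributes, weightsMeta_, tensorValuesMeta_) and flips placementApplied_, so it is a one-shot call per graph. A small usage sketch follows; the caller-side names are assumed, not taken from this diff.

// Sketch only: applyDevicePlacement is a one-shot, in-place rewrite.
torch::nativert::Placement placement;    // default-constructed: keep original devices
graph->applyDevicePlacement(placement);  // node attrs + weight/tensor metadata now hold mapped devices

// A second call would trip the TORCH_CHECK(!placementApplied_) above:
// graph->applyDevicePlacement(placement);  // error: "placement has been applied to the graph! ..."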
@@ -584,6 +584,8 @@ class Graph {
   void setWeightsMeta(
       const std::unordered_map<std::string, torch::_export::TensorMeta>&
           tensorsMeta) {
+    TORCH_CHECK(!placementApplied_);
+
     for (auto [name, tensorMeta] : tensorsMeta) {
       weightsMeta_.emplace(name, TensorMeta{tensorMeta});
     }
@@ -605,6 +607,8 @@ class Graph {
   void setTensorValuesMeta(
       const std::unordered_map<std::string, torch::_export::TensorMeta>&
           tensorsMeta) {
+    TORCH_CHECK(!placementApplied_);
+
     for (auto [name, tensorMeta] : tensorsMeta) {
       tensorValuesMeta_.emplace(name, TensorMeta{tensorMeta});
     }
@@ -630,6 +634,8 @@ class Graph {
   friend std::ostream& operator<<(std::ostream& out, const Graph& g);
   GraphSignature signature_;

+  bool placementApplied_ = false;
+
   // keys are parameters, buffers, tensor_constants' names
   std::unordered_map<std::string, TensorMeta> weightsMeta_;

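The placementApplied_ flag added above, together with the TORCH_CHECK(!placementApplied_) in setWeightsMeta and setTensorValuesMeta, enforces an ordering: all tensor metadata must be registered before placement is applied, so no entry misses the device rewrite. A sketch of that ordering; the metadata maps here are assumed.

// Sketch of the ordering the new checks enforce; weightsTensorMeta and
// valuesTensorMeta are assumed maps of std::string -> torch::_export::TensorMeta.
graph->setWeightsMeta(weightsTensorMeta);      // allowed: placement not applied yet
graph->setTensorValuesMeta(valuesTensorMeta);  // allowed
graph->applyDevicePlacement(placement);        // sets placementApplied_ = true

// Registering more metadata afterwards would fail the new check:
// graph->setWeightsMeta(lateMeta);            // TORCH_CHECK(!placementApplied_) fires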
@@ -10,6 +10,7 @@
 #include <c10/util/ArrayRef.h>

 #include <torch/csrc/utils/generated_serialization_types.h>
+#include <torch/nativert/executor/Placement.h>

 namespace torch::nativert {

@@ -68,6 +69,11 @@ class TensorMeta {
         requiresGrad_);
   }

+  // override device according to placement
+  void applyDevicePlacement(const Placement& placement) {
+    device_ = placement.getMappedDevice(device_);
+  }
+
   // NYI
   // c10::SymIntArrayRef sym_sizes() const {}
   // c10::SymIntArrayRef sym_strides() const {}
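TensorMeta now remaps its own device through Placement::getMappedDevice. The real Placement class (torch/nativert/executor/Placement.h) is not shown in this diff; the stand-in below is a hypothetical, simplified illustration of the mapping contract the new method relies on — mapped devices are rewritten, everything else passes through unchanged. It is not the actual API.

#include <utility>
#include <vector>
#include <c10/core/Device.h>

// Hypothetical, simplified stand-in for the Placement device mapping; only
// meant to illustrate what TensorMeta::applyDevicePlacement expects from
// getMappedDevice. Not the real torch::nativert::Placement.
class SimplePlacement {
 public:
  explicit SimplePlacement(
      std::vector<std::pair<c10::Device, c10::Device>> mapping)
      : mapping_(std::move(mapping)) {}

  c10::Device getMappedDevice(const c10::Device& d) const {
    for (const auto& [from, to] : mapping_) {
      if (from == d) {
        return to;  // device is remapped
      }
    }
    return d;       // unmapped devices pass through unchanged
  }

 private:
  std::vector<std::pair<c10::Device, c10::Device>> mapping_;
};

// Example: map cuda:0 -> cuda:1, leave everything else alone.
//   SimplePlacement p({{c10::Device("cuda:0"), c10::Device("cuda:1")}});
//   p.getMappedDevice(c10::Device("cuda:0"));  // -> cuda:1
//   p.getMappedDevice(c10::Device("cpu"));     // -> cpu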
@@ -22,8 +22,7 @@ namespace {
 c10::Device inferTargetDevice(
     const Node& node,
     const std::unordered_map<std::string, torch::nativert::TensorMeta>&
-        tensorValuesMeta,
-    const Placement& placement) {
+        tensorValuesMeta) {
   if (node.target() == "prim.Input" || node.target() == "prim.Output") {
     return c10::Device(c10::DeviceType::CPU);
   }
@@ -56,7 +55,7 @@ c10::Device inferTargetDevice(
     }
   }

-  return placement.getMappedDevice(devices[0]);
+  return devices[0];
 }
 }

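After this change, inferTargetDevice consults only the graph's TensorMeta (whose devices were already remapped by Graph::applyDevicePlacement) and no longer takes a Placement. The sketch below is a hypothetical, simplified reconstruction of that behavior, not the real function body; the accessors on Node and its output values are assumptions.

#include <string>
#include <unordered_map>
#include <vector>
#include <c10/core/Device.h>
// torch/nativert headers for Node and TensorMeta are assumed to be included.

// Hypothetical simplified illustration only: placement is never consulted here.
c10::Device inferTargetDeviceSketch(
    const torch::nativert::Node& node,
    const std::unordered_map<std::string, torch::nativert::TensorMeta>&
        tensorValuesMeta) {
  if (node.target() == "prim.Input" || node.target() == "prim.Output") {
    return c10::Device(c10::DeviceType::CPU);
  }
  std::vector<c10::Device> devices;
  for (const auto& output : node.outputs()) {    // assumed accessor
    auto it = tensorValuesMeta.find(std::string(output->name()));  // assumed accessor
    if (it != tensorValuesMeta.end()) {
      devices.push_back(it->second.device());    // device() is shown in this diff
    }
  }
  // Fall back to CPU if no metadata was found (assumption, for a self-contained sketch).
  return devices.empty() ? c10::Device(c10::DeviceType::CPU) : devices[0];
}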
@@ -126,7 +125,6 @@ ExecutionKernels KernelFactory::initializeNodeKernels(
     const Graph& graph,
     const std::shared_ptr<Weights>& weights,
     const torch::nativert::ExecutorConfig& executorConfig,
-    const Placement& placement,
     const std::shared_ptr<caffe2::serialize::PyTorchStreamReader>&
         pytorchStreamReader) {
   std::vector<std::unique_ptr<OpKernel>> nodeKernels;
@@ -146,7 +144,7 @@ ExecutionKernels KernelFactory::initializeNodeKernels(
     std::string target = std::string(node.target());

     c10::Device targetDevice =
-        inferTargetDevice(node, graph.tensorValuesMeta(), placement);
+        inferTargetDevice(node, graph.tensorValuesMeta());

     bool matched = false;
     for (const auto& [_, handler] : handlers) {
@@ -212,8 +210,8 @@ ExecutionKernels KernelFactory::initializeNodeKernels(
       for (const auto& attr : node.attributes()) {
         if (std::holds_alternative<std::unique_ptr<Graph>>(attr.value)) {
           const auto& subgraph = std::get<std::unique_ptr<Graph>>(attr.value);
-          auto executionKernels = initializeNodeKernels(
-              *subgraph, weights, executorConfig, placement);
+          auto executionKernels =
+              initializeNodeKernels(*subgraph, weights, executorConfig);
           TORCH_CHECK(
               executionKernels.delegateExecutors.empty(),
               "HigherOrderKernel does not support delegates");
@@ -76,7 +76,6 @@ class KernelFactory {
       const Graph& graph,
       const std::shared_ptr<Weights>& weights,
       const torch::nativert::ExecutorConfig& executorConfig,
-      const Placement& placement,
      const std::shared_ptr<caffe2::serialize::PyTorchStreamReader>&
           pytorchStreamReader = nullptr);
