diff --git a/test/cpp/nativert/CMakeLists.txt b/test/cpp/nativert/CMakeLists.txt
index 81ca5869774..10b750d8b39 100644
--- a/test/cpp/nativert/CMakeLists.txt
+++ b/test/cpp/nativert/CMakeLists.txt
@@ -21,6 +21,9 @@ set(NATIVERT_TEST_SRCS
   ${TORCH_ROOT}/torch/nativert/executor/memory/GreedyBySize.cpp
   ${TORCH_ROOT}/torch/nativert/executor/memory/Bump.cpp
   ${TORCH_ROOT}/torch/nativert/executor/memory/DisjointStorageGroups.cpp
+  ${TORCH_ROOT}/torch/nativert/executor/memory/LayoutPlanner.cpp
+  ${TORCH_ROOT}/torch/nativert/executor/memory/LayoutManager.cpp
+  ${TORCH_ROOT}/torch/nativert/executor/memory/AliasAnalyzer.cpp
 )

 add_executable(test_nativert
diff --git a/test/cpp/nativert/test_execution_frame.cpp b/test/cpp/nativert/test_execution_frame.cpp
index ab5fa2e146d..1f4a6975ad9 100644
--- a/test/cpp/nativert/test_execution_frame.cpp
+++ b/test/cpp/nativert/test_execution_frame.cpp
@@ -90,7 +90,9 @@ TEST(ExecutionFrameTest, TestPersistentValue) {
   auto wid = graph->getValue("my_weight")->id();
   EXPECT_NO_THROW(frame.getTensor(wid));

-  EXPECT_DEATH(frame.releaseValue(wid), "Cannot release persistent value");
+  // can't release persistent value
+  frame.releaseValueIfNeeded(wid);
+  EXPECT_FALSE(frame.getIValue(wid).isNone());
 }

 } // namespace torch::nativert
diff --git a/torch/nativert/executor/ExecutionFrame.cpp b/torch/nativert/executor/ExecutionFrame.cpp
index 2aa11e6eaba..40b8cfc6e90 100644
--- a/torch/nativert/executor/ExecutionFrame.cpp
+++ b/torch/nativert/executor/ExecutionFrame.cpp
@@ -29,9 +29,20 @@ ExecutionFrame::ExecutionFrame(const Graph& graph)
   }
 }

-ExecutionFrame::ExecutionFrame(const Graph& graph, const Weights& weights)
+ExecutionFrame::ExecutionFrame(
+    const Graph& graph,
+    const Weights& weights,
+    const torch::nativert::ExecutorConfig& cfg,
+    LayoutPlanner* layoutPlanner)
     : ExecutionFrame(graph) {
   setWeights(weights);
+  if (layoutPlanner != nullptr) {
+    layoutPlanner_ = layoutPlanner;
+    layoutManager_ = std::make_unique<LayoutManager>(
+        *layoutPlanner,
+        *this,
+        cfg.layoutPlannerSettings.layoutManagerSettings());
+  }
 }

 void ExecutionFrame::setWeights(const Weights& weights) {
diff --git a/torch/nativert/executor/ExecutionFrame.h b/torch/nativert/executor/ExecutionFrame.h
index 725f2fd0a8c..d7b2ca232cc 100644
--- a/torch/nativert/executor/ExecutionFrame.h
+++ b/torch/nativert/executor/ExecutionFrame.h
@@ -3,7 +3,9 @@
 #include
 #include

+#include
 #include
+#include
 #include
 #include

@@ -21,7 +23,11 @@ class ExecutionFrame {
   // torch.cond
   explicit ExecutionFrame(const Graph& graph);

-  explicit ExecutionFrame(const Graph& graph, const Weights& weights);
+  explicit ExecutionFrame(
+      const Graph& graph,
+      const Weights& weights,
+      const torch::nativert::ExecutorConfig& executorConfig = {},
+      LayoutPlanner* layoutPlanner = nullptr);

   // Constructor for testing purpose
   explicit ExecutionFrame(
@@ -34,6 +40,16 @@ class ExecutionFrame {
     destroyBorrowedIValues();
   }

+  template <typename CB>
+  auto withMemoryPlanner(CB&& cb) {
+    if (!layoutManager_) {
+      return std::forward<CB>(cb)();
+    }
+
+    LayoutManagerGuard guard(*layoutManager_);
+    return std::forward<CB>(cb)();
+  }
+
   std::vector<c10::IValue> tryMoveUserOutputs();

   c10::IValue moveIValue(ValueId id) {
@@ -79,14 +95,19 @@ class ExecutionFrame {
     return persistent_;
   }

+  C10_ALWAYS_INLINE bool isManagedValue(const ValueId id) const {
+    return layoutPlanner_ != nullptr && layoutPlanner_->is_managed(id);
+  }
+
   void setPersistentIValue(ValueId id, c10::IValue ivalue) {
     setIValue(id, std::move(ivalue));
     persistent_[id] = true;
   }

-  void releaseValue(ValueId id) {
-    CHECK(!persistent_[id]) << "Cannot release persistent value";
-    allValues_[id] = c10::IValue();
+  void releaseValueIfNeeded(ValueId id) {
+    if (!isManagedValue(id) && !persistent_[id]) {
+      allValues_[id] = c10::IValue();
+    }
   }

   void destroyBorrowedIValues() {
@@ -122,6 +143,9 @@ class ExecutionFrame {
   const Graph& graph_;
   WeightVersion weightVersion_ = -1;

+  std::unique_ptr<LayoutManager> layoutManager_;
+  LayoutPlanner* layoutPlanner_{nullptr};
+
   // All the intermediate values for the entire graph, including graph inputs
   // and outputs This table is fixed once constructed
   std::vector<c10::IValue> allValues_;
diff --git a/torch/nativert/executor/SerialGraphExecutor.cpp b/torch/nativert/executor/SerialGraphExecutor.cpp
index ccd52a4bd8d..017f4f178c8 100644
--- a/torch/nativert/executor/SerialGraphExecutor.cpp
+++ b/torch/nativert/executor/SerialGraphExecutor.cpp
@@ -14,19 +14,20 @@ std::vector<c10::IValue> SerialGraphExecutor::execute(

 std::vector<c10::IValue> SerialGraphExecutor::executeWithPrefilledFrame(
     ExecutionFrame& executionFrame) {
-  // Execute kernels for all nodes except prim.Input and prim.Output
-  for (NodeIndex nodeIdx = 1; nodeIdx < nodeKernels_.size() - 1; ++nodeIdx) {
-    nodeKernels_[nodeIdx]->compute(executionFrame);
+  executionFrame.withMemoryPlanner([&]() {
+    // Execute kernels for all nodes except prim.Input and prim.Output
+    for (NodeIndex nodeIdx = 1; nodeIdx < nodeKernels_.size() - 1; ++nodeIdx) {
+      nodeKernels_[nodeIdx]->compute(executionFrame);

-    // don't free intermediate values when static memory planning is enabled
-    if (executorConfig_.tryFreeUnmanagedValuesAfterUse) {
-      // Free the intermediate values that are no used anymore
-      for (const auto& valueKey : execPlan_->valuesToFree[nodeIdx]) {
-        executionFrame.releaseValue(valueKey);
+      // don't free intermediate values when static memory planning is enabled
+      if (executorConfig_.tryFreeUnmanagedValuesAfterUse) {
+        // Free the intermediate values that are no used anymore
+        for (const auto& valueKey : execPlan_->valuesToFree[nodeIdx]) {
+          executionFrame.releaseValueIfNeeded(valueKey);
+        }
       }
     }
-  }
-
+  });
   return executionFrame.tryMoveUserOutputs();
 }

diff --git a/torch/nativert/executor/memory/AliasAnalyzer.cpp b/torch/nativert/executor/memory/AliasAnalyzer.cpp
index 3e6e3ee2d27..e56eb408531 100644
--- a/torch/nativert/executor/memory/AliasAnalyzer.cpp
+++ b/torch/nativert/executor/memory/AliasAnalyzer.cpp
@@ -162,12 +162,13 @@ void AliasAnalyzer::log_state() const {
       for (const auto* a : alias) {
         ss << a->name() << ", ";
       }
-      ss << "\n";
+      ss << '\n';
     }

+    ss << '\n';
+
     return ss.str();
-  }() << std::endl
-      << std::flush;
+  }() << std::flush;
 }

 } // namespace torch::nativert
diff --git a/torch/nativert/executor/memory/LayoutManager.cpp b/torch/nativert/executor/memory/LayoutManager.cpp
index 322cef1c1d7..ae59289dbec 100644
--- a/torch/nativert/executor/memory/LayoutManager.cpp
+++ b/torch/nativert/executor/memory/LayoutManager.cpp
@@ -64,6 +64,7 @@ void LayoutManager::allocate_plan(const LayoutPlan& plan) {
     void* offset_ptr =
         layout_buffer_.get_ptr_with_offset(planned_allocation.offset);

+    // NOLINTNEXTLINE(bugprone-pointer-arithmetic-on-polymorphic-object)
     auto& storage = storage_buf[i];

     // if the existing data ptr doesn't have an associated deleter then we
@@ -124,12 +125,15 @@ void LayoutManager::ensure_managed_storages(bool allocate) {
     } else if (
         C10_UNLIKELY(
             &storage !=
+            // NOLINTNEXTLINE(bugprone-pointer-arithmetic-on-polymorphic-object)
             &storage_buf
                 [i]) /* managed storage was replaced for some reason */) {
       storage.reset();
       tensor->unsafeGetTensorImpl()->set_storage_keep_dtype(at::Storage(
           c10::intrusive_ptr<at::StorageImpl>::unsafe_adapt_non_heap_allocated(
-              &storage_buf[i], 1)));
+              // NOLINTNEXTLINE(bugprone-pointer-arithmetic-on-polymorphic-object)
+              &storage_buf[i],
+              1)));
     }
   }
 }
diff --git a/torch/nativert/executor/memory/LayoutPlanner.cpp b/torch/nativert/executor/memory/LayoutPlanner.cpp
index 87913304d4d..5c45a08ea6f 100644
--- a/torch/nativert/executor/memory/LayoutPlanner.cpp
+++ b/torch/nativert/executor/memory/LayoutPlanner.cpp
@@ -80,7 +80,7 @@ LayoutPlanner::LayoutPlanner(
       continue;
     }

-    if (bool is_consumed = output->users().size() > 0; !is_consumed) {
+    if (bool is_not_consumed = output->users().empty(); is_not_consumed) {
       VLOG(1) << "not planning " << output->name() << " as it has no users";
       continue;
     }
@@ -154,7 +154,7 @@ void LayoutPlanner::initialize_vectors(

     planned_values_[i] = v->id();
     planned_values_historical_max_nbytes_[i] = spec.size;
-    planned_allocation_specs_[i] = std::move(spec);
+    planned_allocation_specs_[i] = spec;

     i++;
   }
@@ -178,9 +178,8 @@ void LayoutPlanner::start_worker_if_not_started() {

     // make sure plan is populated by the time this
     // returns for the first time :P
     create_plan();
-    worker_ = std::thread([this]() {
-      run_periodic(std::bind(&LayoutPlanner::create_plan, this));
-    });
+    worker_ =
+        std::thread([this]() { run_periodic([this] { create_plan(); }); });
   });
 }
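Note (illustrative sketch, not part of the patch): the behavioral core of this change is ExecutionFrame::withMemoryPlanner(), which runs the kernel loop inside a LayoutManagerGuard whenever a LayoutPlanner was attached to the frame and falls through to a plain call otherwise. The standalone reduction below mirrors only that control flow; Planner, Guard, and Frame are hypothetical stand-ins for LayoutManager, LayoutManagerGuard, and ExecutionFrame, under the assumption that the guard simply brackets one execution with allocate/release calls.

```cpp
#include <iostream>
#include <memory>
#include <utility>

struct Planner {
  void allocate() { std::cout << "allocate planned buffer\n"; }
  void deallocate_and_plan() { std::cout << "release buffer, refine plan\n"; }
};

// RAII guard: allocate on entry, release and re-plan on exit, so planned
// storage lives exactly as long as one execution of the wrapped callback.
struct Guard {
  explicit Guard(Planner& p) : planner(p) { planner.allocate(); }
  ~Guard() { planner.deallocate_and_plan(); }
  Planner& planner;
};

class Frame {
 public:
  explicit Frame(bool planning_enabled) {
    if (planning_enabled) {
      planner_ = std::make_unique<Planner>();
    }
  }

  // Run the callback under the guard only when planning is enabled;
  // otherwise call it directly (same shape as the hook added in the diff).
  template <typename CB>
  auto withMemoryPlanner(CB&& cb) {
    if (!planner_) {
      return std::forward<CB>(cb)();
    }
    Guard guard(*planner_);
    return std::forward<CB>(cb)();
  }

 private:
  std::unique_ptr<Planner> planner_;
};

int main() {
  Frame frame(/*planning_enabled=*/true);
  frame.withMemoryPlanner([] { std::cout << "run kernels\n"; });
  return 0;
}
```

Pairing the guard with releaseValueIfNeeded() keeps the two memory strategies from conflicting: values owned by the planner are left alone by the per-node free path, while unmanaged, non-persistent values can still be dropped as soon as their last user has run.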