#pragma once

#include <atomic>
#include <memory>

#include <c10/util/FbcodeMaps.h>
#include <c10/util/Logging.h>
#include <c10/util/Semaphore.h>
#include <c10/util/Synchronized.h>

#include <torch/nativert/detail/ITree.h>
#include <torch/nativert/detail/MPMCQueue.h>
#include <torch/nativert/executor/ConstantFolder.h>
#include <torch/nativert/executor/DelegateExecutor.h>
#include <torch/nativert/executor/ExecutionPlanner.h>
#include <torch/nativert/executor/ExecutorConfig.h>
#include <torch/nativert/executor/GraphExecutorBase.h>
#include <torch/nativert/executor/Placement.h>
#include <torch/nativert/executor/memory/FunctionSchema.h>
#include <torch/nativert/executor/memory/LayoutPlanner.h>
#include <torch/nativert/graph/Graph.h>
#include <torch/nativert/graph/GraphSignature.h>
#include <torch/nativert/kernels/KernelFactory.h>

namespace torch::nativert {

using namespace torch::nativert::detail;

struct DistributedRunConfig;

/**
 * A very dumb executor. Basically just runs each node in order and contains a
 * giant unordered map for every intermediate, no optimizations applied.
 */
class Executor {
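  // Custom deleter for pooled frames: instead of destroying the
  // ExecutionFrame, it hands the frame back to the owning Executor's pool so
  // it can be reused by a later call.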
  class ExecutorFrameDeleter {
   public:
    explicit ExecutorFrameDeleter(Executor& e) : e_(&e) {}
    ExecutorFrameDeleter(ExecutorFrameDeleter&&) = default;
    ExecutorFrameDeleter& operator=(ExecutorFrameDeleter&&) = default;
    ExecutorFrameDeleter(const ExecutorFrameDeleter&) = default;
    ExecutorFrameDeleter& operator=(const ExecutorFrameDeleter&) = default;
    ~ExecutorFrameDeleter() = default;

    void operator()(ExecutionFrame* p) {
      e_->returnExecutorFrameToPool(std::unique_ptr<ExecutionFrame>(p));
    }

   private:
    Executor* e_;
  };
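  // Move-only RAII handle to a pooled ExecutionFrame: dereferences like a
  // pointer and, via ExecutorFrameDeleter, returns the frame to the pool
  // when it goes out of scope.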
  class ExecutorFramePtr {
   public:
    ExecutorFramePtr(std::unique_ptr<ExecutionFrame> ptr, Executor& e)
        : ptr_(std::unique_ptr<ExecutionFrame, ExecutorFrameDeleter>(
              ptr.release(),
              ExecutorFrameDeleter{e})) {}
    ExecutorFramePtr() = delete;
    ExecutorFramePtr(ExecutorFramePtr&&) = default;
    ExecutorFramePtr& operator=(ExecutorFramePtr&&) = default;
    ExecutorFramePtr(const ExecutorFramePtr&) = delete;
    ExecutorFramePtr& operator=(const ExecutorFramePtr&) = delete;
    ~ExecutorFramePtr() = default;

    ExecutionFrame& operator*() {
      return *ptr_;
    }

    ExecutionFrame* operator->() {
      return ptr_.get();
    }

   private:
    std::unique_ptr<ExecutionFrame, ExecutorFrameDeleter> ptr_;
  };

 public:
  // Constructor used for the inference path.
  Executor(
      torch::nativert::ExecutorConfig executorConfig,
      std::shared_ptr<Graph> graph,
      std::shared_ptr<Weights> weights,
      const Placement& placement = Placement(),
      std::shared_ptr<caffe2::serialize::PyTorchStreamReader>
          pytorchStreamReader = nullptr,
      const MakeProxyExecutorFn& makeProxyExecutorFunc = nullptr);

  std::shared_ptr<Weights> getWeights() {
    std::shared_ptr<Weights> ret;
    weights_.withLock([&](auto& w) { ret = w; });
    return ret;
  }

  void processWeights(std::shared_ptr<Weights> weights);
  void atomicSwapWeights(std::shared_ptr<Weights> weights);
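
  // A minimal sketch of a weight update, assuming `newWeights` was produced
  // elsewhere (names are illustrative, not part of this API):
  //
  //   executor.processWeights(newWeights);    // prepare the incoming weights
  //   executor.atomicSwapWeights(newWeights); // publish them to future runs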

  // This API only returns the flattened UserOutputs and is intended for the
  // inference path.
  // TODO: Investigate whether we should remove this; it still seems useful
  // for testing.
  std::vector<c10::IValue> execute(std::vector<c10::IValue> inputs);

  std::vector<c10::IValue> execute(
      const std::vector<c10::IValue>& args,
      const std::unordered_map<std::string, c10::IValue>& kwargs,
      const ITreeSpec& inputTreeSpec);
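
  // A minimal usage sketch, assuming `config`, `graph`, and `weights` are
  // already loaded (variable names are illustrative):
  //
  //   Executor executor(config, graph, weights);
  //   auto outputs = executor.execute({inputIValue});
  //   auto outputs2 = executor.execute(args, kwargs, inputTreeSpec);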

  ProfileMetrics benchmarkIndividualNodes(
      std::vector<std::vector<c10::IValue>> inputsList,
      const uint32_t warmupRuns,
      const uint32_t mainRuns);

  const torch::nativert::GraphSignature& graphSignature() const {
    return graph_->signature();
  }

  static std::string className() {
    return "Executor.v0";
  }

  const torch::nativert::ExecutorConfig& executorConfig() const {
    return executorConfig_;
  }

  std::vector<DelegateExecutor*> getDelegates();

  // Get the number of execution frames in the pool
  int getNumExecutionFrames() const {
    return numExecutionFrames_.load();
  }
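
  // Collects, for each kernel, the FunctionSchema of its target op, keyed by
  // target name.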
  static c10::FastMap<std::string /* target */, torch::nativert::FunctionSchema>
  getKernelSchemas(const std::vector<std::unique_ptr<OpKernel>>& kernels);

 protected:
  torch::nativert::ExecutorConfig executorConfig_;

  std::shared_ptr<Graph> graph_;

  // manages the parameters, buffers and tensor constants
  c10::Synchronized<std::shared_ptr<Weights>> weights_;

  void initialize(
      std::shared_ptr<Weights> weights,
      std::shared_ptr<caffe2::serialize::PyTorchStreamReader>
          pytorchStreamReader);

  ExecutorFramePtr getExecutorFrameFromPool();
  void returnExecutorFrameToPool(std::unique_ptr<ExecutionFrame> frame);
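
  // A minimal sketch of the intended flow (the run step is illustrative):
  //
  //   ExecutorFramePtr frame = getExecutorFrameFromPool();
  //   // ... execute the graph against *frame ...
  //   // on scope exit, the frame is automatically returned to the pool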

  // Clears stale execution frames from the pool
  void clearStaleExecutionFrames();

 private:
  // Structure to track execution frame usage
  struct ExecutionFrameEntry {
    bool used{false};
    std::unique_ptr<ExecutionFrame> frame;

    // Add move constructor and assignment operator
    ExecutionFrameEntry() = default;
    ExecutionFrameEntry(ExecutionFrameEntry&& other) noexcept
        : used(other.used), frame(std::move(other.frame)) {}
    ExecutionFrameEntry& operator=(ExecutionFrameEntry&& other) noexcept {
      used = other.used;
      frame = std::move(other.frame);
      return *this;
    }
    // Delete copy constructor and assignment operator
    ExecutionFrameEntry(const ExecutionFrameEntry&) = delete;
    ExecutionFrameEntry& operator=(const ExecutionFrameEntry&) = delete;
  };

  void maybeRunConstantFolding(std::shared_ptr<Weights> weights);
  void validateInputs(const std::vector<c10::IValue>& inputs) const;

  // Helper method to get current timestamp in seconds
  int64_t getCurrentTimestampSeconds() const;

  std::unique_ptr<GraphExecutorBase> graphExecutor_;

  const Placement placement_;

  // NOTE: delegateExecutors_ is used by nodeKernels_ inside graphExecutor_.
  std::vector<std::unique_ptr<DelegateExecutor>> delegateExecutors_;

  std::vector<ConstFoldingExecution> constFoldingExecutions_;

  std::optional<ConstantFolder> constantFolder_;

  MakeProxyExecutorFn makeProxyExecutorFunc_;
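
  // Execution-frame pool state: sem_ gates availability of frames stored in
  // executionFrames_, while clearedExecutionFrames_ stages entries during
  // stale-frame cleanup.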
  c10::Semaphore sem_;
  torch::nativert::detail::MPMCQueue<std::unique_ptr<ExecutionFrame>>
      executionFrames_;
  torch::nativert::detail::MPMCQueue<ExecutionFrameEntry>
      clearedExecutionFrames_;
  std::atomic_int64_t numExecutionFrames_;

  std::unique_ptr<LayoutPlanner> layoutPlanner_;
  std::atomic_int64_t lastClearedTimestamp_;
  std::mutex cleanupLock_;
  std::atomic_bool clearingInProgress_{false};
};

} // namespace torch::nativert