Summary: Strengthen the matcher for StaticDispatch kernels: all input and output tensors must be on CPU, and all Device-typed attributes must be CPU. Previously, we only checked that output tensors are on CPU, which misses cases such as a DeviceToHost aten._to_copy. This prepares for turning on static dispatch kernels by default. Test Plan: I should add some tests before landing. Rollback Plan: Differential Revision: D78747600 Pull Request resolved: https://github.com/pytorch/pytorch/pull/159187 Approved by: https://github.com/dolpm
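For illustration, a strengthened matcher in the spirit of this summary could look like the sketch below. This is a minimal sketch, not the code from this PR: the sketch::Node, sketch::Value, and sketch::Attribute types and their device fields are hypothetical stand-ins for whatever the real torch::nativert graph API provides.

// Minimal sketch of a strengthened StaticDispatch matcher (NOT the actual
// implementation from this PR). The Node/Value/Attribute types below are
// hypothetical stand-ins for the real torch::nativert graph API.
#include <optional>
#include <vector>

#include <c10/core/Device.h>

namespace sketch {

struct Value {
  std::optional<c10::Device> device; // device of a tensor value, if known
};

struct Attribute {
  std::optional<c10::Device> device; // set only for Device-typed attributes
};

struct Node {
  std::vector<Value> inputs;
  std::vector<Value> outputs;
  std::vector<Attribute> attributes;
};

// Returns true only if every input tensor, every output tensor, and every
// Device-typed attribute is on CPU. Checking inputs as well as outputs
// catches DeviceToHost copies such as aten._to_copy, which a check on
// outputs alone would miss. Unknown devices are treated conservatively
// as non-CPU.
inline bool allOnCpu(const Node& node) {
  auto isCpu = [](const std::optional<c10::Device>& d) {
    return d.has_value() && d->is_cpu();
  };
  for (const auto& v : node.inputs) {
    if (!isCpu(v.device)) return false;
  }
  for (const auto& v : node.outputs) {
    if (!isCpu(v.device)) return false;
  }
  for (const auto& a : node.attributes) {
    if (!isCpu(a.device)) return false;
  }
  return true;
}

} // namespace sketch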
84 lines
2.7 KiB
C++
#pragma once

#include <memory>

#include <torch/csrc/inductor/aoti_torch/proxy_executor.h>
#include <torch/nativert/executor/DelegateExecutor.h>
#include <torch/nativert/executor/ExecutorConfig.h>
#include <torch/nativert/executor/GraphExecutorBase.h>
#include <torch/nativert/executor/OpKernel.h>

namespace torch::nativert {

// Owns the graph executor used to run a const-folding subgraph.
struct ConstFoldingExecution {
  std::unique_ptr<GraphExecutorBase> executor;
};

// The kernels produced for a graph: per-node OpKernels, plus any delegate
// executors and const-folding executions created along the way.
struct ExecutionKernels {
  std::vector<std::unique_ptr<OpKernel>> nodeKernels;
  std::vector<std::unique_ptr<DelegateExecutor>> delegateExecutors;
  std::vector<ConstFoldingExecution> constFoldingExecutions;
};

// A pluggable hook into kernel creation. match() decides whether this handler
// should produce the kernel for a given node; operator() then builds either
// an OpKernel or a DelegateExecutor for it.
class KernelFactoryHandler {
 public:
  using OpKernelPtr = std::unique_ptr<OpKernel>;
  using DelegateExecutorPtr = std::unique_ptr<DelegateExecutor>;
  using Matcher = c10::function_ref<
      bool(const Node& node, const torch::nativert::ExecutorConfig&)>;
  using Callback =
      c10::function_ref<std::pair<OpKernelPtr, DelegateExecutorPtr>(
          const Node&,
          std::shared_ptr<Weights> weights,
          const torch::nativert::ExecutorConfig& executorConfig,
          caffe2::serialize::PyTorchStreamReader* pytorchStreamReader,
          c10::Device targetDevice)>;

  KernelFactoryHandler(Matcher matcher, Callback callback)
      : matcher_(matcher), callback_(callback) {}

  KernelFactoryHandler() = delete;
  KernelFactoryHandler(const KernelFactoryHandler&) = default;
  KernelFactoryHandler& operator=(const KernelFactoryHandler&) = default;
  KernelFactoryHandler(KernelFactoryHandler&&) = default;
  KernelFactoryHandler& operator=(KernelFactoryHandler&&) = default;
  ~KernelFactoryHandler() = default;

  bool match(const Node& node, const torch::nativert::ExecutorConfig& config)
      const {
    return matcher_(node, config);
  }

  std::pair<OpKernelPtr, DelegateExecutorPtr> operator()(
      const Node& node,
      std::shared_ptr<Weights> weights,
      const torch::nativert::ExecutorConfig& executorConfig,
      caffe2::serialize::PyTorchStreamReader* pytorchStreamReader,
      c10::Device targetDevice) const {
    return callback_(
        node, weights, executorConfig, pytorchStreamReader, targetDevice);
  }

 private:
  Matcher matcher_;
  Callback callback_;
};

// Builds kernels for the nodes of a graph. Registered KernelFactoryHandlers
// are consulted during kernel creation for the nodes they match.
class KernelFactory {
 public:
  KernelFactory() = default;

  ExecutionKernels initializeNodeKernels(
      const Graph& graph,
      const std::shared_ptr<Weights>& weights,
      const torch::nativert::ExecutorConfig& executorConfig,
      const std::shared_ptr<caffe2::serialize::PyTorchStreamReader>&
          pytorchStreamReader = nullptr);

  static void registerHandler(
      const std::string& name,
      KernelFactoryHandler handler);
};

} // namespace torch::nativert
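For illustration, here is a minimal usage sketch of the API above. The handler name "my_backend" and the trivial matcher/callback bodies are hypothetical; a real backend would inspect the node and construct a real OpKernel or DelegateExecutor. Free functions are used instead of capturing lambdas because c10::function_ref is non-owning and must not outlive the callable it references.

// Minimal usage sketch (hypothetical, not part of this header): registering
// a KernelFactoryHandler with the KernelFactory. The name "my_backend" and
// the trivial matcher/callback are illustrative placeholders.
#include <torch/nativert/executor/KernelFactory.h>

namespace torch::nativert {

// Hypothetical matcher: claim no nodes by default.
bool myMatcher(const Node& node, const ExecutorConfig& config) {
  (void)node;
  (void)config;
  return false;
}

// Hypothetical callback: a real backend would construct an OpKernel or a
// DelegateExecutor for the matched node here.
std::pair<KernelFactoryHandler::OpKernelPtr,
          KernelFactoryHandler::DelegateExecutorPtr>
myCallback(
    const Node& node,
    std::shared_ptr<Weights> weights,
    const ExecutorConfig& executorConfig,
    caffe2::serialize::PyTorchStreamReader* pytorchStreamReader,
    c10::Device targetDevice) {
  (void)node;
  (void)weights;
  (void)executorConfig;
  (void)pytorchStreamReader;
  (void)targetDevice;
  return {nullptr, nullptr};
}

void registerMyBackend() {
  KernelFactory::registerHandler(
      "my_backend", KernelFactoryHandler(myMatcher, myCallback));
}

} // namespace torch::nativert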