mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[NativeRT] Strengthen matcher check for StaticDispatch kernel (#159187)
Summary: Strengthen the matcher for StaticDispatch kernels: all input and output tensors must be on CPU, and all Device-typed attributes must be CPU. Previously, we only checked that output tensors were on CPU, which missed the case where we do a DeviceToHost aten._to_copy. This prepares for turning on static dispatch kernels by default. Test Plan: I should add some tests before landing. Rollback Plan: Differential Revision: D78747600 Pull Request resolved: https://github.com/pytorch/pytorch/pull/159187 Approved by: https://github.com/dolpm
This commit is contained in:
parent
67e68e0785
commit
e924df23a6
|
|
@ -101,7 +101,10 @@ std::string selectScalarOverloadName(const Node& node) {
|
|||
"floor_divide_out",
|
||||
"_conj"};
|
||||
std::vector<std::string_view> atoms = c10::split(node.target(), '.');
|
||||
TORCH_CHECK(atoms.size() >= 3);
|
||||
|
||||
if (atoms.size() < 3) {
|
||||
return "";
|
||||
}
|
||||
|
||||
std::string ns = std::string{atoms[atoms.size() - 3]};
|
||||
std::string opName = std::string{atoms[atoms.size() - 2]};
|
||||
|
|
@ -110,7 +113,7 @@ std::string selectScalarOverloadName(const Node& node) {
|
|||
overloadName != "Tensor_mode") {
|
||||
return overloadName;
|
||||
}
|
||||
if (allowed.find(std::string{opName}) == allowed.end()) {
|
||||
if (allowed.find(opName) == allowed.end()) {
|
||||
return overloadName;
|
||||
}
|
||||
auto op = c10::Dispatcher::singleton().findSchemaOrThrow(
|
||||
|
|
|
|||
|
|
@ -148,7 +148,7 @@ ExecutionKernels KernelFactory::initializeNodeKernels(
|
|||
|
||||
bool matched = false;
|
||||
for (const auto& [_, handler] : handlers) {
|
||||
if (handler.match(node, executorConfig, targetDevice)) {
|
||||
if (handler.match(node, executorConfig)) {
|
||||
auto [kernel, delegate] = handler(
|
||||
node,
|
||||
weights,
|
||||
|
|
@ -253,7 +253,8 @@ ExecutionKernels KernelFactory::initializeNodeKernels(
|
|||
}
|
||||
}
|
||||
|
||||
if (executorConfig.enableStaticCPUKernels) {
|
||||
if (executorConfig.enableStaticCPUKernels &&
|
||||
!opsWithoutStaticDispatchCount.empty()) {
|
||||
std::stringstream ss;
|
||||
for (const auto& [op, count] : opsWithoutStaticDispatchCount) {
|
||||
ss << op << ": " << count << ", \n";
|
||||
|
|
|
|||
|
|
@ -24,10 +24,8 @@ class KernelFactoryHandler {
|
|||
public:
|
||||
using OpKernelPtr = std::unique_ptr<OpKernel>;
|
||||
using DelegateExecutorPtr = std::unique_ptr<DelegateExecutor>;
|
||||
using Matcher = c10::function_ref<bool(
|
||||
const Node& node,
|
||||
const torch::nativert::ExecutorConfig&,
|
||||
c10::Device)>;
|
||||
using Matcher = c10::function_ref<
|
||||
bool(const Node& node, const torch::nativert::ExecutorConfig&)>;
|
||||
using Callback =
|
||||
c10::function_ref<std::pair<OpKernelPtr, DelegateExecutorPtr>(
|
||||
const Node&,
|
||||
|
|
@ -46,11 +44,9 @@ class KernelFactoryHandler {
|
|||
KernelFactoryHandler& operator=(KernelFactoryHandler&&) = default;
|
||||
~KernelFactoryHandler() = default;
|
||||
|
||||
bool match(
|
||||
const Node& node,
|
||||
const torch::nativert::ExecutorConfig& config,
|
||||
c10::Device device) const {
|
||||
return matcher_(node, config, device);
|
||||
bool match(const Node& node, const torch::nativert::ExecutorConfig& config)
|
||||
const {
|
||||
return matcher_(node, config);
|
||||
}
|
||||
|
||||
std::pair<OpKernelPtr, DelegateExecutorPtr> operator()(
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user