[NativeRT] Strengthen matcher check for StaticDispatch kernel (#159187)

Summary:
Strengthen the matcher for StaticDispatch kernels: all input and output tensors must be on CPU, and all Device-typed attributes must be CPU.

Previously, we only checked that the output tensors were on CPU. That check fails to catch the case where we do a DeviceToHost aten._to_copy.

This prepares for turning on static dispatch kernels by default.

Test Plan:
I should add some tests before landing.

Rollback Plan:

Differential Revision: D78747600

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159187
Approved by: https://github.com/dolpm
This commit is contained in:
Sherlock Huang 2025-07-29 04:03:49 +00:00 committed by PyTorch MergeBot
parent 67e68e0785
commit e924df23a6
3 changed files with 13 additions and 13 deletions

View File

@ -101,7 +101,10 @@ std::string selectScalarOverloadName(const Node& node) {
"floor_divide_out",
"_conj"};
std::vector<std::string_view> atoms = c10::split(node.target(), '.');
TORCH_CHECK(atoms.size() >= 3);
if (atoms.size() < 3) {
return "";
}
std::string ns = std::string{atoms[atoms.size() - 3]};
std::string opName = std::string{atoms[atoms.size() - 2]};
@ -110,7 +113,7 @@ std::string selectScalarOverloadName(const Node& node) {
overloadName != "Tensor_mode") {
return overloadName;
}
if (allowed.find(std::string{opName}) == allowed.end()) {
if (allowed.find(opName) == allowed.end()) {
return overloadName;
}
auto op = c10::Dispatcher::singleton().findSchemaOrThrow(

View File

@ -148,7 +148,7 @@ ExecutionKernels KernelFactory::initializeNodeKernels(
bool matched = false;
for (const auto& [_, handler] : handlers) {
if (handler.match(node, executorConfig, targetDevice)) {
if (handler.match(node, executorConfig)) {
auto [kernel, delegate] = handler(
node,
weights,
@ -253,7 +253,8 @@ ExecutionKernels KernelFactory::initializeNodeKernels(
}
}
if (executorConfig.enableStaticCPUKernels) {
if (executorConfig.enableStaticCPUKernels &&
!opsWithoutStaticDispatchCount.empty()) {
std::stringstream ss;
for (const auto& [op, count] : opsWithoutStaticDispatchCount) {
ss << op << ": " << count << ", \n";

View File

@ -24,10 +24,8 @@ class KernelFactoryHandler {
public:
using OpKernelPtr = std::unique_ptr<OpKernel>;
using DelegateExecutorPtr = std::unique_ptr<DelegateExecutor>;
using Matcher = c10::function_ref<bool(
const Node& node,
const torch::nativert::ExecutorConfig&,
c10::Device)>;
using Matcher = c10::function_ref<
bool(const Node& node, const torch::nativert::ExecutorConfig&)>;
using Callback =
c10::function_ref<std::pair<OpKernelPtr, DelegateExecutorPtr>(
const Node&,
@ -46,11 +44,9 @@ class KernelFactoryHandler {
KernelFactoryHandler& operator=(KernelFactoryHandler&&) = default;
~KernelFactoryHandler() = default;
bool match(
const Node& node,
const torch::nativert::ExecutorConfig& config,
c10::Device device) const {
return matcher_(node, config, device);
bool match(const Node& node, const torch::nativert::ExecutorConfig& config)
const {
return matcher_(node, config);
}
std::pair<OpKernelPtr, DelegateExecutorPtr> operator()(