diff --git a/torch/nativert/graph/GraphPasses.cpp b/torch/nativert/graph/GraphPasses.cpp index 981a63815db..6cb378af80d 100644 --- a/torch/nativert/graph/GraphPasses.cpp +++ b/torch/nativert/graph/GraphPasses.cpp @@ -101,7 +101,10 @@ std::string selectScalarOverloadName(const Node& node) { "floor_divide_out", "_conj"}; std::vector atoms = c10::split(node.target(), '.'); - TORCH_CHECK(atoms.size() >= 3); + + if (atoms.size() < 3) { + return ""; + } std::string ns = std::string{atoms[atoms.size() - 3]}; std::string opName = std::string{atoms[atoms.size() - 2]}; @@ -110,7 +113,7 @@ std::string selectScalarOverloadName(const Node& node) { overloadName != "Tensor_mode") { return overloadName; } - if (allowed.find(std::string{opName}) == allowed.end()) { + if (allowed.find(opName) == allowed.end()) { return overloadName; } auto op = c10::Dispatcher::singleton().findSchemaOrThrow( diff --git a/torch/nativert/kernels/KernelFactory.cpp b/torch/nativert/kernels/KernelFactory.cpp index 88740d2aa45..d433a5dc0a2 100644 --- a/torch/nativert/kernels/KernelFactory.cpp +++ b/torch/nativert/kernels/KernelFactory.cpp @@ -148,7 +148,7 @@ ExecutionKernels KernelFactory::initializeNodeKernels( bool matched = false; for (const auto& [_, handler] : handlers) { - if (handler.match(node, executorConfig, targetDevice)) { + if (handler.match(node, executorConfig)) { auto [kernel, delegate] = handler( node, weights, @@ -253,7 +253,8 @@ ExecutionKernels KernelFactory::initializeNodeKernels( } } - if (executorConfig.enableStaticCPUKernels) { + if (executorConfig.enableStaticCPUKernels && + !opsWithoutStaticDispatchCount.empty()) { std::stringstream ss; for (const auto& [op, count] : opsWithoutStaticDispatchCount) { ss << op << ": " << count << ", \n"; diff --git a/torch/nativert/kernels/KernelFactory.h b/torch/nativert/kernels/KernelFactory.h index 77308e0055b..21c9c5dffde 100644 --- a/torch/nativert/kernels/KernelFactory.h +++ b/torch/nativert/kernels/KernelFactory.h @@ -24,10 +24,8 @@ class KernelFactoryHandler { public: using OpKernelPtr = std::unique_ptr; using DelegateExecutorPtr = std::unique_ptr; - using Matcher = c10::function_ref; + using Matcher = c10::function_ref< + bool(const Node& node, const torch::nativert::ExecutorConfig&)>; using Callback = c10::function_ref( const Node&, @@ -46,11 +44,9 @@ class KernelFactoryHandler { KernelFactoryHandler& operator=(KernelFactoryHandler&&) = default; ~KernelFactoryHandler() = default; - bool match( - const Node& node, - const torch::nativert::ExecutorConfig& config, - c10::Device device) const { - return matcher_(node, config, device); + bool match(const Node& node, const torch::nativert::ExecutorConfig& config) + const { + return matcher_(node, config); } std::pair operator()(