#include <torch/nativert/kernels/KernelFactory.h>

#include <fmt/format.h>
#include <fmt/ranges.h>

#include <c10/core/Device.h>
#include <c10/util/Logging.h>
#include <c10/util/Synchronized.h>

#include <torch/nativert/executor/ParallelGraphExecutor.h>
#include <torch/nativert/executor/SerialGraphExecutor.h>
#include <torch/nativert/graph/Graph.h>
#include <torch/nativert/kernels/AutoFunctionalizeKernel.h>
#include <torch/nativert/kernels/C10Kernel.h>
#include <torch/nativert/kernels/CallTorchBindKernel.h>
#include <torch/nativert/kernels/HigherOrderKernel.h>
#include <torch/nativert/kernels/PrimKernelRegistry.h>

namespace torch::nativert {

namespace {

// Picks the device a node's kernel should target by inspecting the tensor
// metadata of its outputs. Graph inputs/outputs and nodes without tensor
// outputs default to CPU.
c10::Device inferTargetDevice(
    const Node& node,
    const std::unordered_map<std::string, TensorMeta>& tensorValuesMeta) {
  if (node.target() == "prim.Input" || node.target() == "prim.Output") {
    return c10::Device(c10::DeviceType::CPU);
  }

  std::vector<c10::Device> devices;
  for (auto& output : node.outputs()) {
    if (output->type() == Type::Kind::Tensor) {
      auto it = tensorValuesMeta.find(std::string{output->name()});
      if (it != tensorValuesMeta.end()) {
        devices.emplace_back(it->second.device());
      }
    } else if (output->type() == Type::Kind::TensorList) {
      for (const auto& el : output->getListElements()) {
        auto it = tensorValuesMeta.find(std::string{el->name()});
        if (it != tensorValuesMeta.end()) {
          devices.emplace_back(it->second.device());
        }
      }
    }
  }

  if (devices.empty()) {
    return c10::Device(c10::DeviceType::CPU);
  } else {
    for (size_t i = 1; i < devices.size(); ++i) {
      if (!torch::nativert::isSameDevice(devices[0], devices[i])) {
        LOG(WARNING) << "Node " << node
                     << " has outputs on multiple devices: " << devices[0]
                     << " and " << devices[i];
      }
    }

    return devices[0];
  }
}

} // namespace

inline constexpr std::array<std::string_view, 7> kSymIntOps = {
    "_operator.floordiv",
    "_operator.mod",
    "torch.sym_int",
    "torch.sym_float",
    "torch.sym_ite",
    "torch.sym_max",
    "torch.sym_min",
};

inline constexpr std::array<std::string_view, 8> kSymBoolOps = {
    "_operator.eq",
    "_operator.ne",
    "_operator.le",
    "_operator.ge",
    "_operator.lt",
    "_operator.gt",
    "_operator.and_",
    "torch.sym_not",
};

inline constexpr std::array<std::string_view, 4> kSymFloatOps = {
    "torch._sym_sqrt",
    "math.trunc",
    "_operator.neg",
    "_operator.truediv",
};

inline constexpr std::array<std::string_view, 4> kScalarBinaryOps = {
    "_operator.mul",
    "_operator.add",
    "_operator.sub",
    "_operator.pow",
};

namespace {

struct KernelFactoryRegistry {
  std::unordered_map<std::string, KernelFactoryHandler> handlers;
};

c10::Synchronized<KernelFactoryRegistry>& getKernelFactoryRegistry() {
  static auto* registry = new c10::Synchronized<KernelFactoryRegistry>();
  return *registry;
}

} // namespace

void KernelFactory::registerHandler(
    const std::string& name,
    KernelFactoryHandler handler) {
  auto& registry = getKernelFactoryRegistry();
  registry.withLock([&](auto&& reg) {
    if (reg.handlers.find(name) != reg.handlers.end()) {
      TORCH_CHECK(false, "Handler for ", name, " already registered");
    }
    reg.handlers.emplace(name, std::move(handler));
  });
}

ExecutionKernels KernelFactory::initializeNodeKernels(
    const Graph& graph,
    const std::shared_ptr<Weights>& weights,
    const torch::nativert::ExecutorConfig& executorConfig,
    const std::shared_ptr<caffe2::serialize::PyTorchStreamReader>&
        pytorchStreamReader) {
  std::vector<std::unique_ptr<OpKernel>> nodeKernels;
  std::vector<std::unique_ptr<DelegateExecutor>> delegateExecutors;
  std::vector<ConstFoldingExecution> constFoldingExecutions;

  std::unordered_map<std::string, int> opsWithoutStaticDispatchCount;

  VLOG(1) << fmt::format(
      "PrimKernelRegistry: {}", fmt::join(PrimKernelRegistry()->Keys(), ", "));

  // Snapshot the registered handlers so the registry lock is not held while
  // kernels are being constructed.
  std::unordered_map<std::string, KernelFactoryHandler> handlers;
  getKernelFactoryRegistry().withLock(
      [&](auto&& reg) { handlers = reg.handlers; });

  for (const auto& node : graph.nodes()) {
    std::string target = std::string(node.target());

    c10::Device targetDevice =
        inferTargetDevice(node, graph.tensorValuesMeta());

    // Registered handlers (e.g. delegate backends) get first pick of the node.
    bool matched = false;
    for (const auto& [_, handler] : handlers) {
      if (handler.match(node, executorConfig, targetDevice)) {
        auto [kernel, delegate] = handler(
            node,
            weights,
            executorConfig,
            pytorchStreamReader.get(),
            targetDevice);
        if (kernel) {
          nodeKernels.push_back(std::move(kernel));
        }
        if (delegate) {
          delegateExecutors.push_back(std::move(delegate));
        }
        matched = true;
        break;
      }
    }
    if (matched) {
      continue;
    }

    if (PrimKernelRegistry()->Has(target)) {
      nodeKernels.push_back(PrimKernelRegistry()->Create(target, &node));
    } else if (c10::starts_with(
                   node.target(), "torch.ops.higher_order.call_torchbind")) {
      nodeKernels.push_back(std::make_unique<CallTorchBindKernel>(&node));
    } else if (
        c10::starts_with(
            node.target(), "torch.ops.higher_order.auto_functionalized") ||
        c10::starts_with( // TODO Remove this condition once the old
                          // pt2 archives are expired.
            node.target(),
            "torch._higher_order_ops.auto_functionalize.auto_functionalized")) {
      nodeKernels.push_back(
          std::make_unique<UnsafeAutoFunctionalizeKernel>(&node));
    } else if (
        std::find(
            std::begin(kSymIntOps), std::end(kSymIntOps), node.target()) !=
        std::end(kSymIntOps)) {
      nodeKernels.push_back(std::make_unique<SymIntOpKernel>(&node));
    } else if (
        std::find(
            std::begin(kSymBoolOps), std::end(kSymBoolOps), node.target()) !=
        std::end(kSymBoolOps)) {
      nodeKernels.push_back(std::make_unique<SymBoolOpKernel>(&node));
    } else if (
        std::find(
            std::begin(kSymFloatOps), std::end(kSymFloatOps), node.target()) !=
        std::end(kSymFloatOps)) {
      nodeKernels.push_back(std::make_unique<SymFloatOpKernel>(&node));
    } else if (
        std::find(
            std::begin(kScalarBinaryOps),
            std::end(kScalarBinaryOps),
            node.target()) != std::end(kScalarBinaryOps)) {
      nodeKernels.push_back(std::make_unique<ScalarBinaryOpKernel>(&node));
    } else if (c10::starts_with(node.target(), "torch.ops.higher_order")) {
      // Generic higher-order op: recursively build kernels for each subgraph
      // attribute and wrap them in per-subgraph executors.
      std::vector<std::unique_ptr<GraphExecutorBase>> graphExecutors;
      for (const auto& attr : node.attributes()) {
        if (std::holds_alternative<std::unique_ptr<Graph>>(attr.value)) {
          const auto& subgraph = std::get<std::unique_ptr<Graph>>(attr.value);
          auto executionKernels =
              initializeNodeKernels(*subgraph, weights, executorConfig);
          TORCH_CHECK(
              executionKernels.delegateExecutors.empty(),
              "HigherOrderKernel does not support delegates");
          TORCH_CHECK(
              executionKernels.constFoldingExecutions.empty(),
              "HigherOrderKernel does not support const folding");
          if (executorConfig.maxParallelOps > 1) {
            graphExecutors.emplace_back(
                std::unique_ptr<GraphExecutorBase>(new ParallelGraphExecutor(
                    *subgraph,
                    std::move(executionKernels.nodeKernels),
                    executorConfig)));
          } else {
            graphExecutors.emplace_back(std::unique_ptr<GraphExecutorBase>(
                new torch::nativert::SerialGraphExecutor(
                    *subgraph,
                    std::move(executionKernels.nodeKernels),
                    executorConfig)));
          }
        }
      }
      if (node.target() == "torch.ops.higher_order.run_const_graph") {
        constFoldingExecutions.push_back(
            ConstFoldingExecution{std::move(graphExecutors[0])});
      }
      nodeKernels.push_back(std::make_unique<HigherOrderKernel>(
          &node, std::move(graphExecutors)));
    } else if (c10::starts_with(node.target(), "torch.ops")) {
      // Ops without a static or delegated kernel fall back to the generic c10
      // dispatcher kernel; keep a per-op tally for the warning below.
      nodeKernels.push_back(std::make_unique<C10Kernel>(&node));

      std::string opName = std::string(node.target());
      if (opsWithoutStaticDispatchCount.find(opName) ==
          opsWithoutStaticDispatchCount.end()) {
        opsWithoutStaticDispatchCount[opName] = 0;
      }
      opsWithoutStaticDispatchCount[opName] += 1;
    } else {
      TORCH_CHECK(false, "Unsupported operator: ", target);
    }
  }

  if (executorConfig.enableStaticCPUKernels) {
    std::stringstream ss;
    for (const auto& [op, count] : opsWithoutStaticDispatchCount) {
      ss << op << ": " << count << ", \n";
    }
    LOG(WARNING) << "Following ops are missing static dispatched kernels: \n"
                 << ss.str();
  }

  return {
      std::move(nodeKernels),
      std::move(delegateExecutors),
      std::move(constFoldingExecutions)};
}

} // namespace torch::nativert
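
// ---------------------------------------------------------------------------
// Illustrative usage sketch (kept as a comment so it is not compiled into this
// translation unit). It shows one way a caller might consume
// KernelFactory::initializeNodeKernels(); the `graph`, `weights`,
// `executorConfig`, and `reader` variables are assumed to be set up elsewhere
// by the caller.
//
//   KernelFactory factory;
//   ExecutionKernels kernels =
//       factory.initializeNodeKernels(graph, weights, executorConfig, reader);
//
//   // Single-threaded execution over the produced per-node kernels, mirroring
//   // the serial fallback used for higher-order subgraphs above.
//   torch::nativert::SerialGraphExecutor executor(
//       graph, std::move(kernels.nodeKernels), executorConfig);
// ---------------------------------------------------------------------------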