From 310e8361c565ca1602e719e4c812dc3931ec84d7 Mon Sep 17 00:00:00 2001
From: Yiming Zhou
Date: Tue, 24 Jun 2025 21:42:38 +0000
Subject: [PATCH] [nativert] Move PrimKernelRegistry to PyTorch core (#156506)

Summary:
Torch Native Runtime RFC: pytorch/rfcs#72

PrimKernelRegistry manages a small subset of the kernel registry in
NativeRT, including ListPack, ListUnpack, Input, Output, VarConcat,
and VarStack.

Test Plan: Internal unittests

Differential Revision: D77034945

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156506
Approved by: https://github.com/zhxchen17
---
 build_variables.bzl                           |   1 +
 torch/nativert/kernels/PrimKernelRegistry.cpp | 172 ++++++++++++++++++
 torch/nativert/kernels/PrimKernelRegistry.h   |  40 ++++
 3 files changed, 213 insertions(+)
 create mode 100644 torch/nativert/kernels/PrimKernelRegistry.cpp
 create mode 100644 torch/nativert/kernels/PrimKernelRegistry.h

diff --git a/build_variables.bzl b/build_variables.bzl
index 28d0c8bf7a7..cdb75d5746d 100644
--- a/build_variables.bzl
+++ b/build_variables.bzl
@@ -614,6 +614,7 @@ libtorch_nativert_sources = [
     "torch/nativert/executor/memory/GreedyBySize.cpp",
     "torch/nativert/executor/memory/Bump.cpp",
     "torch/nativert/kernels/CallTorchBindKernel.cpp",
+    "torch/nativert/kernels/PrimKernelRegistry.cpp",
 ]
 
 torch_mobile_tracer_sources = [
diff --git a/torch/nativert/kernels/PrimKernelRegistry.cpp b/torch/nativert/kernels/PrimKernelRegistry.cpp
new file mode 100644
index 00000000000..6a400f93fea
--- /dev/null
+++ b/torch/nativert/kernels/PrimKernelRegistry.cpp
@@ -0,0 +1,172 @@
+#include <torch/nativert/kernels/PrimKernelRegistry.h>
+
+#include <c10/util/Enumerate.h>
+#include <c10/util/irange.h>
+
+#include <ATen/CPUFunctions.h>
+#include <ATen/NativeFunctions.h>
+
+namespace torch::nativert {
+
+C10_DEFINE_REGISTRY(PrimKernelRegistry, OpKernel, const Node*);
+
+namespace {
+
+class OpKernel_prim_listpack : public OpKernel {
+ public:
+  explicit OpKernel_prim_listpack(const Node* node)
+      : OpKernel(
+            node,
+            std::nullopt,
+            torch::nativert::OpKernelKind::kPrimKernel) {
+    auto listType = node->outputs()[0]->type();
+    switch (listType.kind()) {
+      case Type::Kind::TensorList:
+        type_ = c10::TensorType::get();
+        break;
+      case Type::Kind::SymIntList:
+        type_ = c10::IntType::get();
+        break;
+      case Type::Kind::OptionalTensorList:
+        type_ = c10::OptionalType::create(c10::TensorType::get());
+        break;
+      default:
+        TORCH_CHECK(false, "Unsupported list type: ", listType);
+    }
+  }
+
+  void computeInternal(ExecutionFrame& executionFrame) const override final {
+    RECORD_USER_SCOPE("sigmoid::OpKernel_prim_listpack");
+    c10::List<c10::IValue> list(type_);
+    list.reserve(numInputs());
+    for (size_t i = 0; i < numInputs(); ++i) {
+      if (KernelInput(i).isNone()) {
+        list.emplace_back();
+      } else {
+        list.push_back(KernelInput(i));
+      }
+    }
+    KernelOutput(0) = std::move(list);
+  }
+
+ private:
+  c10::TypePtr type_;
+};
+
+} // namespace
+
+C10_REGISTER_TYPED_CLASS(
+    PrimKernelRegistry,
+    "prim.ListPack",
+    OpKernel_prim_listpack);
+
+REGISTER_PRIM_KERNEL("prim.ListUnpack", prim_listunpack, {
+  RECORD_USER_SCOPE("sigmoid::OpKernel_prim_listunpack");
+  auto inputListRef = KernelInput(0).toListRef();
+  for (const auto& [i, ivalue] : c10::enumerate(inputListRef)) {
+    KernelOutput(i) = ivalue;
+  }
+});
+
+// Noop for input and output
+REGISTER_PRIM_KERNEL("prim.Input", prim_input, {});
+REGISTER_PRIM_KERNEL("prim.Output", prim_output, {});
+
+namespace {
+
+class OpKernel_variadic_concat : public OpKernel {
+ public:
+  explicit OpKernel_variadic_concat(const Node* node)
+      : OpKernel(
+            node,
+            std::nullopt,
+            torch::nativert::OpKernelKind::kPrimKernel) {
+    dim_ = node_->attributes().size() > 0
+        ? constantToIValue(node_->getAttribute("dim").value).toInt()
+        : 0;
+  }
+  void computeInternal(ExecutionFrame& executionFrame) const override final {
+    {
+      const size_t numNodeInps = numInputs();
+      auto numCatInps = numNodeInps;
+      auto dim = dim_;
+      if (KernelInput(numCatInps - 1).isInt()) {
+        dim = KernelInput(numCatInps - 1).toInt();
+        numCatInps--;
+      }
+      std::vector<at::Tensor> inputs(numCatInps);
+      for (const auto i : c10::irange(numCatInps)) {
+        inputs[i] = KernelInput(i).toTensor();
+      }
+
+      if (KernelOutput(0).isNone()) {
+        KernelOutput(0) = at::cpu::cat(inputs, dim);
+        return;
+      }
+      auto& out_t = KernelOutput(0).toTensor();
+      fastResizeToZero(out_t);
+      at::cpu::cat_outf(inputs, dim, out_t);
+    }
+  }
+
+ private:
+  int64_t dim_;
+};
+
+} // namespace
+
+C10_REGISTER_TYPED_CLASS(
+    PrimKernelRegistry,
+    "prim.VarConcat",
+    OpKernel_variadic_concat);
+
+namespace {
+
+class OpKernel_variadic_stack : public OpKernel {
+ public:
+  explicit OpKernel_variadic_stack(const Node* node)
+      : OpKernel(
+            node,
+            std::nullopt,
+            torch::nativert::OpKernelKind::kPrimKernel) {
+    dim_ = node_->attributes().size() > 0
+        ? constantToIValue(node_->getAttribute("dim").value).toInt()
+        : 0;
+  }
+  void computeInternal(ExecutionFrame& executionFrame) const override final {
+    {
+      const size_t numNodeInps = numInputs();
+      auto numStackInps = numNodeInps;
+      auto dim = dim_;
+      if (KernelInput(numStackInps - 1).isInt()) {
+        dim = KernelInput(numStackInps - 1).toInt();
+        numStackInps--;
+      }
+      std::vector<at::Tensor> inputs(numStackInps);
+      for (const auto i : c10::irange(numStackInps)) {
+        inputs[i] = KernelInput(i).toTensor();
+      }
+      auto& out = KernelOutput(0);
+      if (out.isNone()) {
+        out = at::native::_stack_cpu(inputs, dim);
+        return;
+      }
+      auto& out_t = out.toTensor();
+      fastResizeToZero(out_t);
+      at::native::_stack_out_cpu(inputs, dim, out_t);
+    }
+  }
+
+ private:
+  int64_t dim_;
+};
+} // namespace
+
+C10_REGISTER_TYPED_CLASS(
+    PrimKernelRegistry,
+    "prim.VarStack",
+    OpKernel_variadic_stack);
+
+} // namespace torch::nativert
diff --git a/torch/nativert/kernels/PrimKernelRegistry.h b/torch/nativert/kernels/PrimKernelRegistry.h
new file mode 100644
index 00000000000..791b2a8bb18
--- /dev/null
+++ b/torch/nativert/kernels/PrimKernelRegistry.h
@@ -0,0 +1,40 @@
+#pragma once
+
+#include <torch/nativert/executor/ExecutionFrame.h>
+#include <torch/nativert/executor/OpKernel.h>
+#include <torch/nativert/graph/Graph.h>
+
+namespace torch::nativert {
+
+#define KernelInput(id) input(id, executionFrame)
+#define KernelOutput(id) output(id, executionFrame)
+
+TORCH_DECLARE_REGISTRY(PrimKernelRegistry, OpKernel, const Node*);
+
+#define REGISTER_PRIM_KERNEL(name, id, ...)                     \
+  class OpKernel_##id : public OpKernel {                       \
+   public:                                                      \
+    OpKernel_##id(const Node* node)                             \
+        : OpKernel(                                             \
+              node,                                             \
+              std::nullopt,                                     \
+              torch::nativert::OpKernelKind::kPrimKernel) {}    \
+    void computeInternal(                                       \
+        ExecutionFrame& executionFrame) const override final {  \
+      __VA_ARGS__;                                              \
+    }                                                           \
+  };                                                            \
+  C10_REGISTER_TYPED_CLASS(PrimKernelRegistry, name, OpKernel_##id);
+
+inline bool checkResizedDataPtr(at::Tensor& t) {
+  auto const prev_data_ptr = t.data_ptr();
+  t.resize_({0});
+  return prev_data_ptr == t.data_ptr();
+}
+
+inline void fastResizeToZero(at::Tensor& t) {
+  t.unsafeGetTensorImpl()->set_sizes_contiguous({0});
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(checkResizedDataPtr(t));
+}
+
+} // namespace torch::nativert
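
Reviewer note: the sketch below is not part of the patch. It shows how a kernel
registered through PrimKernelRegistry would typically be instantiated, using the
string-keyed factory that C10_DEFINE_REGISTRY generates. The helper name
makePrimKernel and the use of Node::target() as the registry key are illustrative
assumptions, not APIs introduced by this change.

#include <memory>
#include <string>

#include <torch/nativert/kernels/PrimKernelRegistry.h>

namespace torch::nativert {

// Hypothetical helper (illustration only): look up the kernel class registered
// for this node's op name (e.g. "prim.ListPack") and construct it for the node.
// Create() returns nullptr when no kernel is registered under that key.
inline std::unique_ptr<OpKernel> makePrimKernel(const Node* node) {
  auto kernel = PrimKernelRegistry()->Create(std::string(node->target()), node);
  TORCH_CHECK(kernel != nullptr, "no prim kernel registered for ", node->target());
  return kernel;
}

} // namespace torch::nativert

The resulting kernel is then driven like any other OpKernel: the executor invokes
it with an ExecutionFrame, and computeInternal() reads its inputs and writes its
outputs through the KernelInput/KernelOutput macros defined in the header.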