[nativert] Move PrimKernelRegistry to PyTorch core (#156506)
Summary:
Torch Native Runtime RFC: pytorch/rfcs#72

PrimKernelRegistry manages a small subset of the kernels in NativeRT: ListPack, ListUnpack, Input, Output, VarConcat, and VarStack.

Test Plan: internal unit tests

Differential Revision: D77034945

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156506
Approved by: https://github.com/zhxchen17
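For context, here is a minimal sketch (the call site is assumed, not part of this diff) of how an executor could instantiate one of these kernels through the c10 registry that the new files declare:

#include <memory>
#include <string>

#include <torch/nativert/kernels/PrimKernelRegistry.h>

// Assumed call site: the registry is keyed by op name, and Create()
// forwards the Node* to the matching kernel's constructor; it returns
// nullptr when nothing is registered under the key.
std::unique_ptr<torch::nativert::OpKernel> makePrimKernel(
    const std::string& opName,
    const torch::nativert::Node* node) {
  return torch::nativert::PrimKernelRegistry()->Create(opName, node);
}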
commit 310e8361c5 (parent fa0ea57f5e)
@@ -614,6 +614,7 @@ libtorch_nativert_sources = [
     "torch/nativert/executor/memory/GreedyBySize.cpp",
     "torch/nativert/executor/memory/Bump.cpp",
     "torch/nativert/kernels/CallTorchBindKernel.cpp",
+    "torch/nativert/kernels/PrimKernelRegistry.cpp",
 ]
 
 torch_mobile_tracer_sources = [
torch/nativert/kernels/PrimKernelRegistry.cpp (new file, 172 lines)
@@ -0,0 +1,172 @@
#include <ATen/record_function.h>

#include <ATen/CPUFunctions.h>
#include <c10/core/ScalarType.h>
#include <c10/util/irange.h>
#include <torch/csrc/jit/runtime/static/ops.h>

#include <c10/util/Enumerate.h>
#include <torch/nativert/kernels/PrimKernelRegistry.h>

namespace torch::nativert {

C10_DEFINE_REGISTRY(PrimKernelRegistry, OpKernel, const Node*);

namespace {

class OpKernel_prim_listpack : public OpKernel {
 public:
  explicit OpKernel_prim_listpack(const Node* node)
      : OpKernel(
            node,
            std::nullopt,
            torch::nativert::OpKernelKind::kPrimKernel) {
    auto listType = node->outputs()[0]->type();
    switch (listType.kind()) {
      case Type::Kind::TensorList:
        type_ = c10::TensorType::get();
        break;
      case Type::Kind::SymIntList:
        type_ = c10::IntType::get();
        break;
      case Type::Kind::OptionalTensorList:
        type_ = c10::OptionalType::create(c10::TensorType::get());
        break;
      default:
        TORCH_CHECK(false, "Unsupported list type: ", listType);
    }
  }

  void computeInternal(ExecutionFrame& executionFrame) const override final {
    RECORD_USER_SCOPE("sigmoid::OpKernel_prim_listpack");
    c10::List<c10::IValue> list(type_);
    list.reserve(numInputs());
    for (size_t i = 0; i < numInputs(); ++i) {
      if (KernelInput(i).isNone()) {
        list.emplace_back();
      } else {
        list.push_back(KernelInput(i));
      }
    }
    KernelOutput(0) = std::move(list);
  }

 private:
  c10::TypePtr type_;
};

} // namespace

C10_REGISTER_TYPED_CLASS(
    PrimKernelRegistry,
    "prim.ListPack",
    OpKernel_prim_listpack);

REGISTER_PRIM_KERNEL("prim.ListUnpack", prim_listunpack, {
  RECORD_USER_SCOPE("sigmoid::OpKernel_prim_listunpack");
  auto inputListRef = KernelInput(0).toListRef();
  for (const auto& [i, ivalue] : c10::enumerate(inputListRef)) {
    KernelOutput(i) = ivalue;
  }
});

// Noop for input and output
REGISTER_PRIM_KERNEL("prim.Input", prim_input, {});
REGISTER_PRIM_KERNEL("prim.Output", prim_output, {});

namespace {

class OpKernel_variadic_concat : public OpKernel {
 public:
  explicit OpKernel_variadic_concat(const Node* node)
      : OpKernel(
            node,
            std::nullopt,
            torch::nativert::OpKernelKind::kPrimKernel) {
    dim_ = node_->attributes().size() > 0
        ? constantToIValue(node_->getAttribute("dim").value).toInt()
        : 0;
  }
  void computeInternal(ExecutionFrame& executionFrame) const override final {
    {
      const size_t numNodeInps = numInputs();
      auto numCatInps = numNodeInps;
      auto dim = dim_;
      if (KernelInput(numCatInps - 1).isInt()) {
        dim = KernelInput(numCatInps - 1).toInt();
        numCatInps--;
      }
      std::vector<at::Tensor> inputs(numCatInps);
      for (const auto i : c10::irange(numCatInps)) {
        inputs[i] = KernelInput(i).toTensor();
      }

      if (KernelOutput(0).isNone()) {
        KernelOutput(0) = at::cpu::cat(inputs, dim);
        return;
      }
      auto& out_t = KernelOutput(0).toTensor();
      fastResizeToZero(out_t);
      at::cpu::cat_outf(inputs, dim, out_t);
    }
  }

 private:
  int dim_;
};

} // namespace

C10_REGISTER_TYPED_CLASS(
    PrimKernelRegistry,
    "prim.VarConcat",
    OpKernel_variadic_concat);

namespace {

class OpKernel_variadic_stack : public OpKernel {
 public:
  explicit OpKernel_variadic_stack(const Node* node)
      : OpKernel(
            node,
            std::nullopt,
            torch::nativert::OpKernelKind::kPrimKernel) {
    dim_ = node_->attributes().size() > 0
        ? constantToIValue(node_->getAttribute("dim").value).toInt()
        : 0;
  }
  void computeInternal(ExecutionFrame& executionFrame) const override final {
    {
      const size_t numNodeInps = numInputs();
      auto numStackInps = numNodeInps;
      auto dim = dim_;
      if (KernelInput(numStackInps - 1).isInt()) {
        dim = KernelInput(numStackInps - 1).toInt();
        numStackInps--;
      }
      std::vector<at::Tensor> inputs(numStackInps);
      for (const auto i : c10::irange(numStackInps)) {
        inputs[i] = KernelInput(i).toTensor();
      }
      auto& out = KernelOutput(0);
      if (out.isNone()) {
        out = at::native::_stack_cpu(inputs, dim);
        return;
      }
      auto& out_t = out.toTensor();
      fastResizeToZero(out_t);
      at::native::_stack_out_cpu(inputs, dim, out_t);
    }
  }

 private:
  int64_t dim_;
};
} // namespace

C10_REGISTER_TYPED_CLASS(
    PrimKernelRegistry,
    "prim.VarStack",
    OpKernel_variadic_stack);

} // namespace torch::nativert
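A note on the out-variant pattern used by prim.VarConcat and prim.VarStack above: the first run materializes the output (at::cpu::cat / at::native::_stack_cpu), while later runs resize the cached output to zero and write into it in place, so the executor can reuse the allocation across iterations. Below is a standalone sketch of that pattern under assumed names (catIntoReusableOutput is hypothetical; the kernels themselves use fastResizeToZero from PrimKernelRegistry.h rather than resize_):

#include <optional>

#include <ATen/ATen.h>
#include <ATen/CPUFunctions.h>

// Hypothetical helper mirroring the reuse pattern in OpKernel_variadic_concat:
// allocate the output on the first call, write into it in place afterwards.
void catIntoReusableOutput(
    at::TensorList inputs,
    int64_t dim,
    std::optional<at::Tensor>& out) {
  if (!out.has_value()) {
    out = at::cpu::cat(inputs, dim); // first run: allocate
    return;
  }
  out->resize_({0}); // drop the old sizes, keep the storage
  at::cpu::cat_outf(inputs, dim, *out); // steady state: write in place
}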
torch/nativert/kernels/PrimKernelRegistry.h (new file, 40 lines)
@@ -0,0 +1,40 @@
#pragma once

#include <torch/nativert/executor/OpKernel.h>
#include <torch/nativert/graph/Graph.h>
#include <torch/nativert/kernels/C10Kernel.h>

namespace torch::nativert {

#define KernelInput(id) input(id, executionFrame)
#define KernelOutput(id) output(id, executionFrame)

TORCH_DECLARE_REGISTRY(PrimKernelRegistry, OpKernel, const Node*);

#define REGISTER_PRIM_KERNEL(name, id, ...)                    \
  class OpKernel_##id : public OpKernel {                      \
   public:                                                     \
    OpKernel_##id(const Node* node)                            \
        : OpKernel(                                            \
              node,                                            \
              std::nullopt,                                    \
              torch::nativert::OpKernelKind::kPrimKernel) {}   \
    void computeInternal(                                      \
        ExecutionFrame& executionFrame) const override final { \
      __VA_ARGS__;                                             \
    }                                                          \
  };                                                           \
  C10_REGISTER_TYPED_CLASS(PrimKernelRegistry, name, OpKernel_##id);

inline bool checkResizedDataPtr(at::Tensor& t) {
  auto const prev_data_ptr = t.data_ptr();
  t.resize_({0});
  return prev_data_ptr == t.data_ptr();
}

inline void fastResizeToZero(at::Tensor& t) {
  t.unsafeGetTensorImpl()->set_sizes_contiguous({0});
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(checkResizedDataPtr(t));
}

} // namespace torch::nativert
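To illustrate the macro, here is a hedged sketch of registering an additional prim kernel; "prim.Identity" is a made-up op name and is not registered by this commit:

#include <torch/nativert/kernels/PrimKernelRegistry.h>

namespace torch::nativert {

// Hypothetical kernel that forwards its single input to its single output.
// The macro body becomes computeInternal(); KernelInput/KernelOutput expand
// to input(i, executionFrame) / output(i, executionFrame) per the macros above.
REGISTER_PRIM_KERNEL("prim.Identity", prim_identity, {
  KernelOutput(0) = KernelInput(0);
});

} // namespace torch::nativert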