pytorch/test/edge/operator_registry.cpp
Jack Khuu e9674d146c [Specialized Kernel] Propagate Specialized Kernel Support through ComputeCodegenUnboxedKernels (#103113)
Updating ComputeCodegenUnboxedKernels to accept and write out kernel information to RegisterCodegenUnboxedKernels.cpp

Differential Revision: [D46486195](https://our.internmc.facebook.com/intern/diff/D46486195/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103113
Approved by: https://github.com/larryliu0820, https://github.com/kirklandsign
2023-06-14 10:18:16 +00:00

46 lines
1.1 KiB
C++

#include <c10/util/Exception.h>
#include <operator_registry.h>
namespace torch {
namespace executor {
/// Returns the single KernelRegistry instance used by the free functions in
/// this file.
///
/// The instance is a function-local static: it is constructed the first time
/// this function is called and lives until program exit. (Presumably chosen so
/// kernels can be registered from other translation units' static
/// initializers without init-order problems — TODO confirm.)
KernelRegistry& getKernelRegistry() {
  static KernelRegistry registry_instance;
  return registry_instance;
}
bool register_kernels(const ArrayRef<Kernel>& kernels) {
return getKernelRegistry().register_kernels(kernels);
}
bool KernelRegistry::register_kernels(
const ArrayRef<Kernel>& kernels) {
for (const auto& kernel : kernels) {
this->kernels_map_[kernel.name_] = kernel.kernel_;
}
return true;
}
bool hasKernelFn(const char* name) {
return getKernelRegistry().hasKernelFn(name);
}
bool KernelRegistry::hasKernelFn(const char* name) {
auto kernel = this->kernels_map_.find(name);
return kernel != this->kernels_map_.end();
}
/// Fetches the kernel called `name` from the global registry.
///
/// @param name kernel name to look up.
/// @return a reference to the registered KernelFunction; lookup failure is
///         handled by KernelRegistry::getKernelFn.
KernelFunction& getKernelFn(const char* name) {
  KernelRegistry& registry = getKernelRegistry();
  return registry.getKernelFn(name);
}
/// Looks up the kernel registered under `name` in this registry.
///
/// @param name kernel name to look up.
/// @return a reference to the stored KernelFunction.
/// @throws c10::Error (raised by TORCH_CHECK) if no kernel with that name is
///         registered.
KernelFunction& KernelRegistry::getKernelFn(const char* name) {
  auto kernel = this->kernels_map_.find(name);
  // TORCH_CHECK, not TORCH_CHECK_MSG: the latter only formats a message string
  // and performs no check. Without a real check, a missing kernel would
  // dereference end() below, which is undefined behavior.
  TORCH_CHECK(
      kernel != this->kernels_map_.end(), "Kernel not found: ", name);
  return kernel->second;
}
} // namespace executor
} // namespace torch