mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Updating ComputeCodegenUnboxedKernels to accept and write out kernel information to RegisterCodegenUnboxedKernels.cpp Differential Revision: [D46486195](https://our.internmc.facebook.com/intern/diff/D46486195/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/103113 Approved by: https://github.com/larryliu0820, https://github.com/kirklandsign
46 lines
1.1 KiB
C++
46 lines
1.1 KiB
C++
#include <c10/util/Exception.h>
|
|
#include <operator_registry.h>
|
|
|
|
namespace torch {
|
|
namespace executor {
|
|
|
|
// Returns the single process-wide KernelRegistry.
//
// The registry is a function-local static: it is constructed on the first
// call (initialization of such statics is thread-safe since C++11) and the
// same instance is returned on every later call. Presumably this is chosen
// over a namespace-scope global so that registration performed from static
// initializers in other translation units never sees an unconstructed
// registry. NOTE(review): confirm that is the intent.
KernelRegistry& getKernelRegistry() {
  static KernelRegistry registry;
  return registry;
}
|
|
|
|
bool register_kernels(const ArrayRef<Kernel>& kernels) {
|
|
return getKernelRegistry().register_kernels(kernels);
|
|
}
|
|
|
|
bool KernelRegistry::register_kernels(
|
|
const ArrayRef<Kernel>& kernels) {
|
|
for (const auto& kernel : kernels) {
|
|
this->kernels_map_[kernel.name_] = kernel.kernel_;
|
|
}
|
|
return true;
|
|
}
|
|
|
|
bool hasKernelFn(const char* name) {
|
|
return getKernelRegistry().hasKernelFn(name);
|
|
}
|
|
|
|
bool KernelRegistry::hasKernelFn(const char* name) {
|
|
auto kernel = this->kernels_map_.find(name);
|
|
return kernel != this->kernels_map_.end();
|
|
}
|
|
|
|
// Looks up the kernel named `name` in the process-wide registry returned by
// getKernelRegistry() by forwarding to KernelRegistry::getKernelFn.
//
// Callers that are unsure whether the kernel exists can call
// hasKernelFn(name) first.
KernelFunction& getKernelFn(const char* name) {
  KernelRegistry& registry = getKernelRegistry();
  return registry.getKernelFn(name);
}
|
|
|
|
// Returns a reference to the kernel registered under `name`.
//
// Throws c10::Error (via TORCH_CHECK) if no kernel with that name is
// registered, so the iterator returned by find() is never dereferenced when
// it equals end().
//
// TORCH_CHECK must be used here, not TORCH_CHECK_MSG: the latter only builds
// the failure-message string and neither tests the condition nor throws.
KernelFunction& KernelRegistry::getKernelFn(const char* name) {
  auto it = kernels_map_.find(name);
  TORCH_CHECK(it != kernels_map_.end(), "Kernel not found: ", name);
  return it->second;
}
|
|
|
|
|
|
} // namespace executor
|
|
} // namespace torch
|