[nativert][triton] improve hardware registration (#162499)

Summary: att

Test Plan:
ci

Rollback Plan:

Differential Revision: D82031814

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162499
Approved by: https://github.com/angelayi
This commit is contained in:
dolpm 2025-09-10 04:52:57 +00:00 committed by PyTorch MergeBot
parent 96ef26f71a
commit 1c16c18a53
8 changed files with 142 additions and 88 deletions

View File

@@ -552,6 +552,11 @@ if(USE_CUDA OR USE_ROCM)
append_filelist("libtorch_cuda_core_sources" Caffe2_GPU_HIP_JIT_FUSERS_SRCS) append_filelist("libtorch_cuda_core_sources" Caffe2_GPU_HIP_JIT_FUSERS_SRCS)
endif() endif()
if(USE_CUDA)
# eventually do rocm
append_filelist("libtorch_nativert_cuda_sources" Caffe2_GPU_SRCS)
endif()
if(USE_CUDA) if(USE_CUDA)
list(APPEND Caffe2_GPU_CU_SRCS ${Caffe2_GPU_HIP_JIT_FUSERS_SRCS}) list(APPEND Caffe2_GPU_CU_SRCS ${Caffe2_GPU_HIP_JIT_FUSERS_SRCS})
add_library(caffe2_nvrtc SHARED ${ATen_NVRTC_STUB_SRCS}) add_library(caffe2_nvrtc SHARED ${ATen_NVRTC_STUB_SRCS})

View File

@@ -40,21 +40,24 @@ set(NATIVERT_TEST_SRCS
${TORCH_ROOT}/torch/nativert/graph/passes/pass_manager/GraphPasses.cpp ${TORCH_ROOT}/torch/nativert/graph/passes/pass_manager/GraphPasses.cpp
${TORCH_ROOT}/torch/nativert/graph/passes/pass_manager/PassManager.cpp ${TORCH_ROOT}/torch/nativert/graph/passes/pass_manager/PassManager.cpp
${TORCH_ROOT}/torch/nativert/kernels/KernelHandlerRegistry.cpp ${TORCH_ROOT}/torch/nativert/kernels/KernelHandlerRegistry.cpp
${TORCH_ROOT}/torch/nativert/kernels/TritonKernel.cpp
${TORCH_ROOT}/torch/nativert/executor/triton/CpuTritonKernelManager.cpp ${TORCH_ROOT}/torch/nativert/executor/triton/CpuTritonKernelManager.cpp
${TORCH_ROOT}/torch/nativert/kernels/TritonKernel.cpp
${TORCH_ROOT}/torch/nativert/executor/DelegateExecutor.cpp ${TORCH_ROOT}/torch/nativert/executor/DelegateExecutor.cpp
) )
if(USE_CUDA) if(USE_CUDA)
list(APPEND NATIVERT_TEST_SRCS ${TORCH_ROOT}/torch/nativert/executor/triton/CudaTritonKernelManager.cpp) list(APPEND NATIVERT_TEST_SRCS ${TORCH_ROOT}/torch/nativert/executor/triton/CudaTritonKernelManager.cpp)
endif(MSVC) endif()
add_executable(test_nativert add_executable(test_nativert
${TORCH_ROOT}/test/cpp/common/main.cpp ${TORCH_ROOT}/test/cpp/common/main.cpp
${NATIVERT_TEST_SRCS} ${NATIVERT_TEST_SRCS}
) )
if(MSVC)
target_compile_definitions(test_nativert PRIVATE NATIVERT_MSVC_TEST)
endif()
# TODO temporary until we can delete the old gtest polyfills. # TODO temporary until we can delete the old gtest polyfills.
target_compile_definitions(test_nativert PRIVATE USE_GTEST) target_compile_definitions(test_nativert PRIVATE USE_GTEST)

View File

@@ -6,9 +6,20 @@ using namespace ::testing;
using namespace torch::nativert; using namespace torch::nativert;
TEST(TritonKernelManagerRegistrationTests, TestRegister) { TEST(TritonKernelManagerRegistrationTests, TestRegister) {
#ifndef USE_CUDA EXPECT_TRUE(TritonKernelManagerRegistry()->Has(at::kCPU));
EXPECT_TRUE(create_cuda_triton_kernel_manager == nullptr);
#ifdef USE_CUDA
#ifdef USE_ROCM
EXPECT_TRUE(TritonKernelManagerRegistry()->Has(at::kHIP));
EXPECT_FALSE(TritonKernelManagerRegistry()->Has(at::kCUDA));
#else #else
EXPECT_FALSE(create_cuda_triton_kernel_manager == nullptr); EXPECT_TRUE(TritonKernelManagerRegistry()->Has(at::kCUDA));
EXPECT_FALSE(TritonKernelManagerRegistry()->Has(at::kHIP));
#endif // USE_ROCM
#else
EXPECT_FALSE(TritonKernelManagerRegistry()->Has(at::kCUDA));
EXPECT_FALSE(TritonKernelManagerRegistry()->Has(at::kHIP));
#endif // USE_CUDA #endif // USE_CUDA
} }

View File

@@ -1,5 +1,6 @@
#include <torch/nativert/executor/triton/CpuTritonKernelManager.h> #include <torch/nativert/executor/triton/TritonKernelManager.h>
#include <c10/util/FbcodeMaps.h>
#include <c10/util/Logging.h> #include <c10/util/Logging.h>
#ifndef _WIN32 #ifndef _WIN32
@@ -35,6 +36,43 @@ char* _dlerror() {
} // namespace } // namespace
typedef void* kernel_ptr_t;
typedef void (
*launcher_ptr_t)(uint32_t, uint32_t, uint32_t, void**, kernel_ptr_t);
struct DlcloseDeleter {
void operator()(void* p) const {
if (p) {
#if defined(_WIN32)
TORCH_CHECK(false, "Windows is not supported");
#else
dlclose(p);
#endif
}
}
};
class CpuTritonKernelManager final : public TritonKernelManager {
public:
CpuTritonKernelManager(
std::string kernel_name,
std::string kernel_bin_path,
std::string kernel_launcher_bin_path);
~CpuTritonKernelManager() final = default;
void launch(const LaunchParams& launch_params, void** args) final;
private:
void load();
kernel_ptr_t kernel_fn_{nullptr};
launcher_ptr_t launcher_fn_{nullptr};
std::unique_ptr<void, DlcloseDeleter> kernel_handle_{nullptr};
std::unique_ptr<void, DlcloseDeleter> launcher_handle_{nullptr};
std::string kernel_launcher_bin_path_;
};
CpuTritonKernelManager::CpuTritonKernelManager( CpuTritonKernelManager::CpuTritonKernelManager(
std::string kernel_name, std::string kernel_name,
std::string kernel_bin_path, std::string kernel_bin_path,
@@ -88,4 +126,21 @@ void CpuTritonKernelManager::launch(
kernel_fn_); kernel_fn_);
} }
namespace {
std::unique_ptr<TritonKernelManager> create_cpu_triton_kernel_manager(
std::string kernel_name,
std::string kernel_bin_path,
std::string kernel_launcher_bin_path) {
return std::make_unique<CpuTritonKernelManager>(
std::move(kernel_name),
std::move(kernel_bin_path),
std::move(kernel_launcher_bin_path));
}
} // namespace
C10_REGISTER_TYPED_CREATOR(
TritonKernelManagerRegistry,
at::kCPU,
create_cpu_triton_kernel_manager)
} // namespace torch::nativert } // namespace torch::nativert

View File

@@ -1,51 +0,0 @@
#pragma once
#include <torch/nativert/executor/triton/TritonKernelManager.h>
#include <c10/core/Device.h>
#include <c10/util/FbcodeMaps.h>
#ifndef _WIN32
#include <dlfcn.h>
#endif
typedef void* kernel_ptr_t;
typedef void (
*launcher_ptr_t)(uint32_t, uint32_t, uint32_t, void**, kernel_ptr_t);
namespace torch::nativert {
struct DlcloseDeleter {
void operator()(void* p) const {
if (p) {
#if defined(_WIN32)
TORCH_CHECK(false, "Windows is not supported");
#else
dlclose(p);
#endif
}
}
};
class CpuTritonKernelManager final : public TritonKernelManager {
public:
CpuTritonKernelManager(
std::string kernel_name,
std::string kernel_bin_path,
std::string kernel_launcher_bin_path);
~CpuTritonKernelManager() final = default;
void launch(const LaunchParams& launch_params, void** args) final;
private:
void load();
kernel_ptr_t kernel_fn_{nullptr};
launcher_ptr_t launcher_fn_{nullptr};
std::unique_ptr<void, DlcloseDeleter> kernel_handle_{nullptr};
std::unique_ptr<void, DlcloseDeleter> launcher_handle_{nullptr};
std::string kernel_launcher_bin_path_;
};
} // namespace torch::nativert

View File

@@ -29,7 +29,7 @@ namespace torch::nativert {
class CudaKernelInputs final : public KernelInputs { class CudaKernelInputs final : public KernelInputs {
public: public:
CudaKernelInputs(size_t num_args, size_t num_attrs) CudaKernelInputs(size_t num_args, size_t num_attrs)
: KernelInputs(num_args, num_attrs), arg_ptrs_(num_args) {}; : KernelInputs(num_args, num_attrs), arg_ptrs_(num_args) {}
~CudaKernelInputs() final = default; ~CudaKernelInputs() final = default;
void add_arg(void* arg) override { void add_arg(void* arg) override {
@@ -73,7 +73,7 @@ CudaTritonKernelManager::CudaTritonKernelManager(
TORCH_CHECK( TORCH_CHECK(
at::globalContext().hasCUDA() || at::globalContext().hasHIP(), at::globalContext().hasCUDA() || at::globalContext().hasHIP(),
"cuda or hip required"); "cuda or hip required");
}; }
CudaTritonKernelManager::~CudaTritonKernelManager() { CudaTritonKernelManager::~CudaTritonKernelManager() {
const auto& nvrtc = get_nvrtc(); const auto& nvrtc = get_nvrtc();
@@ -137,19 +137,31 @@ void CudaTritonKernelManager::launch(
nullptr)); nullptr));
} }
static std::unique_ptr<TritonKernelManager> _create_cuda_triton_kernel_manager( namespace {
std::unique_ptr<TritonKernelManager> create_cuda_triton_kernel_manager(
std::string kernel_name, std::string kernel_name,
std::string kernel_bin_path) { std::string kernel_bin_path,
// NOLINTNEXTLINE(performance-unnecessary-value-param)
[[maybe_unused]] std::string kernel_launcher_bin_path) {
return std::make_unique<CudaTritonKernelManager>( return std::make_unique<CudaTritonKernelManager>(
std::move(kernel_name), std::move(kernel_bin_path)); std::move(kernel_name), std::move(kernel_bin_path));
} }
} // namespace
#ifdef USE_ROCM
C10_REGISTER_TYPED_CREATOR(
TritonKernelManagerRegistry,
at::kHIP,
create_cuda_triton_kernel_manager)
#else
C10_REGISTER_TYPED_CREATOR(
TritonKernelManagerRegistry,
at::kCUDA,
create_cuda_triton_kernel_manager)
#endif // USE_ROCM
} // namespace torch::nativert } // namespace torch::nativert
namespace {
static bool _initialized_cuda_triton_kernel_manager = []() {
torch::nativert::create_cuda_triton_kernel_manager =
&torch::nativert::_create_cuda_triton_kernel_manager;
return true;
}();
} // namespace

View File

@@ -2,7 +2,9 @@
#include <string> #include <string>
#include <c10/core/DeviceType.h>
#include <c10/util/Exception.h> #include <c10/util/Exception.h>
#include <c10/util/Registry.h>
namespace torch::nativert { namespace torch::nativert {
@@ -69,7 +71,13 @@ class TritonKernelManager {
std::string kernel_name_, kernel_bin_path_; std::string kernel_name_, kernel_bin_path_;
}; };
inline std::unique_ptr<TritonKernelManager> ( C10_DECLARE_TYPED_REGISTRY(
*create_cuda_triton_kernel_manager)(std::string, std::string) = nullptr; TritonKernelManagerRegistry,
c10::DeviceType,
TritonKernelManager,
std::unique_ptr,
std::string /* kernel_name */,
std::string /* kernel_bin_path */,
std::string /* kernel_launcher_bin_path */);
} // namespace torch::nativert } // namespace torch::nativert

View File

@@ -16,10 +16,20 @@
#include <ATen/ops/empty.h> #include <ATen/ops/empty.h>
#endif #endif
#include <torch/nativert/executor/triton/CpuTritonKernelManager.h>
namespace torch::nativert { namespace torch::nativert {
// in this case, we want to use the symbol from torch_cpu.dll
#ifndef NATIVERT_MSVC_TEST
C10_DEFINE_TYPED_REGISTRY(
TritonKernelManagerRegistry,
c10::DeviceType,
TritonKernelManager,
std::unique_ptr,
std::string /* kernel_name */,
std::string /* kernel_bin_path */,
std::string /* kernel_launcher_bin_path */)
#endif
TritonKernel::TritonKernel( TritonKernel::TritonKernel(
const Node* node, const Node* node,
caffe2::serialize::PyTorchStreamReader* reader) caffe2::serialize::PyTorchStreamReader* reader)
@@ -74,27 +84,28 @@ TritonKernel::TritonKernel(
auto tmp_dir = extractToTemporaryFolder(*reader, kernel_prefix) + "/"; auto tmp_dir = extractToTemporaryFolder(*reader, kernel_prefix) + "/";
if (reader->hasRecord(kernel_prefix + "/" + kernel_name + ".cubin")) { if (reader->hasRecord(kernel_prefix + "/" + kernel_name + ".cubin")) {
loader_ = TritonKernelManagerRegistry()->Create(
at::kCUDA, kernel_name, tmp_dir + kernel_name + ".cubin", "");
TORCH_CHECK( TORCH_CHECK(
create_cuda_triton_kernel_manager != nullptr, loader_ != nullptr,
"couldn't find cuda loader -- is this a gpu build?"); "couldn't find cuda loader -- is this a gpu build?");
loader_ = create_cuda_triton_kernel_manager( } else if (reader->hasRecord(kernel_prefix + "/" + kernel_name + ".hsaco")) {
kernel_name, tmp_dir + kernel_name + ".cubin"); loader_ = TritonKernelManagerRegistry()->Create(
} at::kHIP, kernel_name, tmp_dir + kernel_name + ".hsaco", "");
if (reader->hasRecord(kernel_prefix + "/" + kernel_name + ".hsaco")) {
TORCH_CHECK( TORCH_CHECK(
create_cuda_triton_kernel_manager != nullptr, loader_ != nullptr,
"couldn't find cuda loader -- is this a gpu build?"); "couldn't find cuda loader -- is this a gpu build?");
loader_ = create_cuda_triton_kernel_manager( } else {
kernel_name, tmp_dir + kernel_name + ".hsaco"); loader_ = TritonKernelManagerRegistry()->Create(
} at::kCPU,
if (loader_ == nullptr) {
loader_ = std::unique_ptr<TritonKernelManager>(new CpuTritonKernelManager(
kernel_name, kernel_name,
tmp_dir + kernel_name + ".so", tmp_dir + kernel_name + ".so",
tmp_dir + kernel_name + ".launcher.so")); tmp_dir + kernel_name + ".launcher.so");
} }
TORCH_CHECK(
loader_ != nullptr,
"couldn't find triton kernel loader -- are you trying to run gpu kernels on a cpu build?");
} }
TritonKernel::~TritonKernel() = default; TritonKernel::~TritonKernel() = default;