[ROCm] Exposing Some MIOpen Symbols (#2176) (#154545)

This PR exposes some MIOpen symbols, namely:

1. `miopenDataType_t getMiopenDataType(const at::Tensor& tensor)`
2. `miopenHandle_t getMiopenHandle()`
3. `class TensorDescriptor`
4. `class Descriptor`
5. `class FilterDescriptor`
6. `struct ConvolutionDescriptor`
7. `struct DropoutDescriptor`
8. `struct RNNDescriptor`

to enable adding extensions that make use of them.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154545
Approved by: https://github.com/jeffdaily, https://github.com/Skylion007

Co-authored-by: Jeff Daily <jeff.daily@amd.com>
This commit is contained in:
Arash Pakbin 2025-05-29 21:10:45 +00:00 committed by PyTorch MergeBot
parent 83a0e4e6f9
commit 08fdc64c86
3 changed files with 36 additions and 35 deletions

View File

@ -5,6 +5,7 @@
#include <ATen/miopen/miopen-wrapper.h>
#include <ATen/core/Tensor.h>
#include <ATen/TensorUtils.h>
#include <c10/macros/Export.h>
namespace at { namespace native {
@ -37,9 +38,9 @@ struct DescriptorDeleter {
// initialized the first time you call set() or any other initializing
// function.
template <typename T, miopenStatus_t (*ctor)(T**), miopenStatus_t (*dtor)(T*)>
class Descriptor
{
public:
// NOLINTNEXTLINE(bugprone-exception-escape)
class TORCH_CUDA_CPP_API Descriptor {
public:
// Use desc() to access the underlying descriptor pointer in
// a read-only fashion. Most client code should use this.
// If the descriptor was never initialized, this will return
@ -55,7 +56,7 @@ public:
protected:
void init() {
if (desc_ == nullptr) {
T* raw_desc;
T* raw_desc = nullptr;
MIOPEN_CHECK(ctor(&raw_desc));
desc_.reset(raw_desc);
}
@ -64,13 +65,12 @@ private:
std::unique_ptr<T, DescriptorDeleter<T, dtor>> desc_;
};
class TensorDescriptor
: public Descriptor<miopenTensorDescriptor,
&miopenCreateTensorDescriptor,
&miopenDestroyTensorDescriptor>
{
public:
TensorDescriptor() {}
class TORCH_CUDA_CPP_API TensorDescriptor : public Descriptor<
miopenTensorDescriptor,
&miopenCreateTensorDescriptor,
&miopenDestroyTensorDescriptor> {
public:
TensorDescriptor() = default;
explicit TensorDescriptor(const at::Tensor &t, size_t pad = 0) {
set(t, pad);
}
@ -88,11 +88,10 @@ private:
std::ostream& operator<<(std::ostream & out, const TensorDescriptor& d);
class FilterDescriptor
: public Descriptor<miopenTensorDescriptor,
&miopenCreateTensorDescriptor,
&miopenDestroyTensorDescriptor>
{
class TORCH_CUDA_CPP_API FilterDescriptor : public Descriptor<
miopenTensorDescriptor,
&miopenCreateTensorDescriptor,
&miopenDestroyTensorDescriptor> {
public:
void set(const at::Tensor &t, int64_t pad = 0) {
set(t, at::MemoryFormat::Contiguous, pad);
@ -106,11 +105,11 @@ private:
}
};
struct ConvolutionDescriptor
: public Descriptor<miopenConvolutionDescriptor,
&miopenCreateConvolutionDescriptor,
&miopenDestroyConvolutionDescriptor>
{
struct TORCH_CUDA_CPP_API ConvolutionDescriptor
: public Descriptor<
miopenConvolutionDescriptor,
&miopenCreateConvolutionDescriptor,
&miopenDestroyConvolutionDescriptor> {
void set(miopenDataType_t dataType, miopenConvolutionMode_t c_mode, int dim, int* pad, int* stride, int * upscale /* aka dilation */, int groups, bool benchmark, bool deterministic) {
MIOPEN_CHECK(miopenInitConvolutionNdDescriptor(mut_desc(), dim, pad, stride, upscale, c_mode));
MIOPEN_CHECK(miopenSetConvolutionGroupCount(mut_desc(), groups));
@ -121,11 +120,12 @@ struct ConvolutionDescriptor
}
};
struct DropoutDescriptor
: public Descriptor<miopenDropoutDescriptor,
&miopenCreateDropoutDescriptor,
&miopenDestroyDropoutDescriptor>
{
// NOLINTNEXTLINE(bugprone-exception-escape)
struct TORCH_CUDA_CPP_API DropoutDescriptor
: public Descriptor<
miopenDropoutDescriptor,
&miopenCreateDropoutDescriptor,
&miopenDestroyDropoutDescriptor> {
void set(miopenHandle_t handle, float dropout, void* states, size_t stateSizeInBytes,
unsigned long long seed, bool use_mask, bool state_evo, miopenRNGType_t rng_mode) {
MIOPEN_CHECK(miopenSetDropoutDescriptor(mut_desc(), handle, dropout, states, stateSizeInBytes, seed, use_mask, state_evo, rng_mode));
@ -137,7 +137,7 @@ struct DropoutDescriptor
}
};
struct RNNDescriptor
struct TORCH_CUDA_CPP_API RNNDescriptor
: public Descriptor<miopenRNNDescriptor,
&miopenCreateRNNDescriptor,
&miopenDestroyRNNDescriptor>

View File

@ -1,9 +1,9 @@
#pragma once
#include <ATen/miopen/miopen-wrapper.h>
#include <c10/macros/Export.h>
namespace at { namespace native {
namespace at::native {
miopenHandle_t getMiopenHandle();
}} // namespace
TORCH_CUDA_CPP_API miopenHandle_t getMiopenHandle();
} // namespace at::native

View File

@ -1,12 +1,13 @@
#pragma once
#include <ATen/miopen/miopen-wrapper.h>
#include <ATen/Tensor.h>
#include <ATen/miopen/miopen-wrapper.h>
#include <c10/macros/Export.h>
namespace at { namespace native {
namespace at::native {
miopenDataType_t getMiopenDataType(const at::Tensor& tensor);
TORCH_CUDA_CPP_API miopenDataType_t getMiopenDataType(const at::Tensor& tensor);
int64_t miopen_version();
}} // namespace at::miopen
} // namespace at::native