mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
This PR exposes some MIOpen symbols, namely: 1. `miopenDataType_t getMiopenDataType(const at::Tensor& tensor)` 2. `miopenHandle_t getMiopenHandle()` 3. `class TensorDescriptor` 4. `class Descriptor` 5. `class FilterDescriptor` 6. `struct ConvolutionDescriptor` 7. `struct DropoutDescriptor` 8. `struct RNNDescriptor` to enable adding extensions that make use of them. Pull Request resolved: https://github.com/pytorch/pytorch/pull/154545 Approved by: https://github.com/jeffdaily, https://github.com/Skylion007 Co-authored-by: Jeff Daily <jeff.daily@amd.com>
This commit is contained in:
parent
83a0e4e6f9
commit
08fdc64c86
|
|
@ -5,6 +5,7 @@
|
|||
#include <ATen/miopen/miopen-wrapper.h>
|
||||
#include <ATen/core/Tensor.h>
|
||||
#include <ATen/TensorUtils.h>
|
||||
#include <c10/macros/Export.h>
|
||||
|
||||
namespace at { namespace native {
|
||||
|
||||
|
|
@ -37,9 +38,9 @@ struct DescriptorDeleter {
|
|||
// initialized the first time you call set() or any other initializing
|
||||
// function.
|
||||
template <typename T, miopenStatus_t (*ctor)(T**), miopenStatus_t (*dtor)(T*)>
|
||||
class Descriptor
|
||||
{
|
||||
public:
|
||||
// NOLINTNEXTLINE(bugprone-exception-escape)
|
||||
class TORCH_CUDA_CPP_API Descriptor {
|
||||
public:
|
||||
// Use desc() to access the underlying descriptor pointer in
|
||||
// a read-only fashion. Most client code should use this.
|
||||
// If the descriptor was never initialized, this will return
|
||||
|
|
@ -55,7 +56,7 @@ public:
|
|||
protected:
|
||||
void init() {
|
||||
if (desc_ == nullptr) {
|
||||
T* raw_desc;
|
||||
T* raw_desc = nullptr;
|
||||
MIOPEN_CHECK(ctor(&raw_desc));
|
||||
desc_.reset(raw_desc);
|
||||
}
|
||||
|
|
@ -64,13 +65,12 @@ private:
|
|||
std::unique_ptr<T, DescriptorDeleter<T, dtor>> desc_;
|
||||
};
|
||||
|
||||
class TensorDescriptor
|
||||
: public Descriptor<miopenTensorDescriptor,
|
||||
&miopenCreateTensorDescriptor,
|
||||
&miopenDestroyTensorDescriptor>
|
||||
{
|
||||
public:
|
||||
TensorDescriptor() {}
|
||||
class TORCH_CUDA_CPP_API TensorDescriptor : public Descriptor<
|
||||
miopenTensorDescriptor,
|
||||
&miopenCreateTensorDescriptor,
|
||||
&miopenDestroyTensorDescriptor> {
|
||||
public:
|
||||
TensorDescriptor() = default;
|
||||
explicit TensorDescriptor(const at::Tensor &t, size_t pad = 0) {
|
||||
set(t, pad);
|
||||
}
|
||||
|
|
@ -88,11 +88,10 @@ private:
|
|||
|
||||
std::ostream& operator<<(std::ostream & out, const TensorDescriptor& d);
|
||||
|
||||
class FilterDescriptor
|
||||
: public Descriptor<miopenTensorDescriptor,
|
||||
&miopenCreateTensorDescriptor,
|
||||
&miopenDestroyTensorDescriptor>
|
||||
{
|
||||
class TORCH_CUDA_CPP_API FilterDescriptor : public Descriptor<
|
||||
miopenTensorDescriptor,
|
||||
&miopenCreateTensorDescriptor,
|
||||
&miopenDestroyTensorDescriptor> {
|
||||
public:
|
||||
void set(const at::Tensor &t, int64_t pad = 0) {
|
||||
set(t, at::MemoryFormat::Contiguous, pad);
|
||||
|
|
@ -106,11 +105,11 @@ private:
|
|||
}
|
||||
};
|
||||
|
||||
struct ConvolutionDescriptor
|
||||
: public Descriptor<miopenConvolutionDescriptor,
|
||||
&miopenCreateConvolutionDescriptor,
|
||||
&miopenDestroyConvolutionDescriptor>
|
||||
{
|
||||
struct TORCH_CUDA_CPP_API ConvolutionDescriptor
|
||||
: public Descriptor<
|
||||
miopenConvolutionDescriptor,
|
||||
&miopenCreateConvolutionDescriptor,
|
||||
&miopenDestroyConvolutionDescriptor> {
|
||||
void set(miopenDataType_t dataType, miopenConvolutionMode_t c_mode, int dim, int* pad, int* stride, int * upscale /* aka dilation */, int groups, bool benchmark, bool deterministic) {
|
||||
MIOPEN_CHECK(miopenInitConvolutionNdDescriptor(mut_desc(), dim, pad, stride, upscale, c_mode));
|
||||
MIOPEN_CHECK(miopenSetConvolutionGroupCount(mut_desc(), groups));
|
||||
|
|
@ -121,11 +120,12 @@ struct ConvolutionDescriptor
|
|||
}
|
||||
};
|
||||
|
||||
struct DropoutDescriptor
|
||||
: public Descriptor<miopenDropoutDescriptor,
|
||||
&miopenCreateDropoutDescriptor,
|
||||
&miopenDestroyDropoutDescriptor>
|
||||
{
|
||||
// NOLINTNEXTLINE(bugprone-exception-escape)
|
||||
struct TORCH_CUDA_CPP_API DropoutDescriptor
|
||||
: public Descriptor<
|
||||
miopenDropoutDescriptor,
|
||||
&miopenCreateDropoutDescriptor,
|
||||
&miopenDestroyDropoutDescriptor> {
|
||||
void set(miopenHandle_t handle, float dropout, void* states, size_t stateSizeInBytes,
|
||||
unsigned long long seed, bool use_mask, bool state_evo, miopenRNGType_t rng_mode) {
|
||||
MIOPEN_CHECK(miopenSetDropoutDescriptor(mut_desc(), handle, dropout, states, stateSizeInBytes, seed, use_mask, state_evo, rng_mode));
|
||||
|
|
@ -137,7 +137,7 @@ struct DropoutDescriptor
|
|||
}
|
||||
};
|
||||
|
||||
struct RNNDescriptor
|
||||
struct TORCH_CUDA_CPP_API RNNDescriptor
|
||||
: public Descriptor<miopenRNNDescriptor,
|
||||
&miopenCreateRNNDescriptor,
|
||||
&miopenDestroyRNNDescriptor>
|
||||
|
|
|
|||
|
|
@ -1,9 +1,9 @@
|
|||
#pragma once
|
||||
|
||||
#include <ATen/miopen/miopen-wrapper.h>
|
||||
#include <c10/macros/Export.h>
|
||||
|
||||
namespace at { namespace native {
|
||||
namespace at::native {
|
||||
|
||||
miopenHandle_t getMiopenHandle();
|
||||
|
||||
}} // namespace
|
||||
TORCH_CUDA_CPP_API miopenHandle_t getMiopenHandle();
|
||||
} // namespace at::native
|
||||
|
|
|
|||
|
|
@ -1,12 +1,13 @@
|
|||
#pragma once
|
||||
|
||||
#include <ATen/miopen/miopen-wrapper.h>
|
||||
#include <ATen/Tensor.h>
|
||||
#include <ATen/miopen/miopen-wrapper.h>
|
||||
#include <c10/macros/Export.h>
|
||||
|
||||
namespace at { namespace native {
|
||||
namespace at::native {
|
||||
|
||||
miopenDataType_t getMiopenDataType(const at::Tensor& tensor);
|
||||
TORCH_CUDA_CPP_API miopenDataType_t getMiopenDataType(const at::Tensor& tensor);
|
||||
|
||||
int64_t miopen_version();
|
||||
|
||||
}} // namespace at::miopen
|
||||
} // namespace at::native
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user