[ROCm] Implemented dropout usage for RNN with MIOpen backend (#144572)

This PR fixes https://github.com/pytorch/pytorch/issues/107183 for ROCm.

Implemented the usage of the new RNN descriptor for the MIOpen backend that takes the dropout rate value into account via a dropout descriptor. This fixes the associated test_RNN_dropout_state test.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144572
Approved by: https://github.com/jeffdaily

Co-authored-by: Jeff Daily <jeff.daily@amd.com>
Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
This commit is contained in:
Iurii Paikov 2025-04-25 21:06:45 +00:00 committed by PyTorch MergeBot
parent 2c5c793085
commit 1aa971a3bb
2 changed files with 106 additions and 7 deletions

View File

@ -121,6 +121,21 @@ struct ConvolutionDescriptor
} }
}; };
// RAII wrapper around a MIOpen dropout descriptor. Thin passthrough over the
// MIOpen C API: the argument order must match the miopenSetDropoutDescriptor /
// miopenRestoreDropoutDescriptor signatures exactly.
struct DropoutDescriptor
: public Descriptor<miopenDropoutDescriptor,
&miopenCreateDropoutDescriptor,
&miopenDestroyDropoutDescriptor>
{
// Initializes the descriptor and seeds the RNG state buffer.
// states / stateSizeInBytes: device buffer holding the PRNG states.
// use_mask: reuse a previously saved dropout mask instead of sampling a new one.
// state_evo: forwarded to MIOpen unchanged (per-invocation state evolution flag).
void set(miopenHandle_t handle, float dropout, void* states, size_t stateSizeInBytes,
unsigned long long seed, bool use_mask, bool state_evo, miopenRNGType_t rng_mode) {
MIOPEN_CHECK(miopenSetDropoutDescriptor(mut_desc(), handle, dropout, states, stateSizeInBytes, seed, use_mask, state_evo, rng_mode));
}
// Restores the descriptor against an already-initialized RNG state buffer
// without re-seeding it (cheaper than set(); used on subsequent calls).
void restore(miopenHandle_t handle, float dropout, void* states, size_t stateSizeInBytes,
unsigned long long seed, bool use_mask, bool state_evo, miopenRNGType_t rng_mode) {
MIOPEN_CHECK(miopenRestoreDropoutDescriptor(mut_desc(), handle, dropout, states, stateSizeInBytes, seed, use_mask, state_evo, rng_mode));
}
};
struct RNNDescriptor struct RNNDescriptor
: public Descriptor<miopenRNNDescriptor, : public Descriptor<miopenRNNDescriptor,
@ -128,9 +143,14 @@ struct RNNDescriptor
&miopenDestroyRNNDescriptor> &miopenDestroyRNNDescriptor>
{ {
void set(int64_t hidden_size, int64_t num_layers, miopenRNNInputMode_t input_mode, miopenRNNDirectionMode_t direction, miopenRNNMode_t rnn_mode, void set(int64_t hidden_size, int64_t num_layers, miopenRNNInputMode_t input_mode, miopenRNNDirectionMode_t direction, miopenRNNMode_t rnn_mode,
miopenRNNBiasMode_t bias_mode, miopenRNNAlgo_t algorithm, miopenDataType_t datatype) { miopenRNNBiasMode_t bias_mode, miopenRNNAlgo_t algorithm, miopenDataType_t datatype) {
MIOPEN_CHECK(miopenSetRNNDescriptor(mut_desc(), hidden_size, num_layers, input_mode, direction, rnn_mode, bias_mode, algorithm, datatype)); MIOPEN_CHECK(miopenSetRNNDescriptor(mut_desc(), hidden_size, num_layers, input_mode, direction, rnn_mode, bias_mode, algorithm, datatype));
} }
// Like set(), but builds the RNN descriptor through miopenSetRNNDescriptor_V2,
// which additionally chains the given dropout descriptor into the RNN plan.
void setWithDropout(DropoutDescriptor& dropout_desc, int64_t hidden_size, int64_t num_layers, miopenRNNInputMode_t input_mode, miopenRNNDirectionMode_t direction,
miopenRNNMode_t rnn_mode, miopenRNNBiasMode_t bias_mode, miopenRNNAlgo_t algorithm, miopenDataType_t datatype) {
MIOPEN_CHECK(miopenSetRNNDescriptor_V2(mut_desc(), hidden_size, num_layers, dropout_desc.mut_desc(), input_mode, direction, rnn_mode, bias_mode, algorithm, datatype));
}
}; };
union Constant union Constant

View File

@ -57,6 +57,10 @@ namespace at::native {
#include <ATen/TensorUtils.h> #include <ATen/TensorUtils.h>
#include <c10/hip/HIPCachingAllocator.h>
#include <rocrand/rocrand_xorwow.h>
#include <functional> #include <functional>
#include <iterator> #include <iterator>
#include <sstream> #include <sstream>
@ -66,12 +70,35 @@ namespace at::native {
#include <stdint.h> #include <stdint.h>
#include <unordered_map> #include <unordered_map>
namespace at { namespace native { namespace at::native {
namespace {
// Owns a device-side buffer of PRNG states used by the MIOpen dropout
// descriptor. The buffer is allocated through the HIP caching allocator and
// released on destruction (RAII). Copying is disabled; moving transfers
// ownership.
struct DropoutState {
  // Allocates `size` bytes of device memory for the RNG states.
  DropoutState(size_t size) : size(size), data(nullptr) {
    data = c10::hip::HIPCachingAllocator::raw_alloc(size);
  }
  DropoutState(const DropoutState&) = delete;
  DropoutState& operator=(const DropoutState&) = delete;
  // NOTE: the move operations must null out the source. The previously
  // defaulted moves copied the raw pointer, so both objects would free the
  // same allocation (double free), and the defaulted move-assignment leaked
  // the target's existing buffer.
  DropoutState(DropoutState&& other) noexcept : size(other.size), data(other.data) {
    other.data = nullptr;
    other.size = 0;
  }
  DropoutState& operator=(DropoutState&& other) noexcept {
    if (this != &other) {
      if (data) {
        c10::hip::HIPCachingAllocator::raw_delete(data);
      }
      data = other.data;
      size = other.size;
      other.data = nullptr;
      other.size = 0;
    }
    return *this;
  }
  ~DropoutState() {
    if (data) {
      c10::hip::HIPCachingAllocator::raw_delete(data);
    }
  }
  size_t size;  // byte size of the device buffer
  void* data;   // device pointer; nullptr after being moved from
};
} // anonymous
//RNNDescriptor. //RNNDescriptor.
struct RNNDescriptorParams { struct RNNDescriptorParams {
int64_t hidden_size; int64_t hidden_size;
int64_t num_layers; int64_t num_layers;
double dropout_rate;
uint64_t dropout_seed;
miopenRNNDirectionMode_t direction; miopenRNNDirectionMode_t direction;
miopenRNNMode_t rnn_mode; miopenRNNMode_t rnn_mode;
miopenDataType_t datatype; miopenDataType_t datatype;
@ -114,6 +141,12 @@ struct RNNDescriptorParams {
} }
} }
// Records the dropout hyperparameters used when building the RNN descriptor.
// A rate of 0.0 selects the plain (no-dropout) descriptor path; see the
// RNNDescriptors constructor.
void set_dropout(double dropout_rate, uint64_t dropout_seed = 0) {
this->dropout_rate = dropout_rate;
// TODO: Implement seed setting for RNN dropout
// (callers currently always pass the default seed of 0)
this->dropout_seed = dropout_seed;
}
void set(int64_t mode, int64_t hidden_size, int64_t num_layers, bool bidirectional, miopenDataType_t datatype, miopenRNNBiasMode_t bias_mode) { void set(int64_t mode, int64_t hidden_size, int64_t num_layers, bool bidirectional, miopenDataType_t datatype, miopenRNNBiasMode_t bias_mode) {
this->set_mode(mode); this->set_mode(mode);
this->hidden_size = hidden_size; this->hidden_size = hidden_size;
@ -128,12 +161,18 @@ struct RNNDescriptorParams {
rnn_desc.set(hidden_size, num_layers, input_mode, direction, rnn_mode, bias_mode, algo, datatype); rnn_desc.set(hidden_size, num_layers, input_mode, direction, rnn_mode, bias_mode, algo, datatype);
return rnn_desc; return rnn_desc;
} }
// Builds an RNN descriptor that chains the given dropout descriptor
// (miopenSetRNNDescriptor_V2 path); used when dropout_rate != 0.
RNNDescriptor descriptorWithDropout(DropoutDescriptor& dropout_desc) const {
RNNDescriptor rnn_desc;
rnn_desc.setWithDropout(dropout_desc, hidden_size, num_layers, input_mode, direction, rnn_mode, bias_mode, algo, datatype);
return rnn_desc;
}
}; };
//TensorDescriptor list. //TensorDescriptor list.
std::vector<TensorDescriptor> rnn_descriptor_sequence(const Tensor& tensor, IntArrayRef batch_sizes) { std::vector<TensorDescriptor> rnn_descriptor_sequence(const Tensor& tensor, IntArrayRef batch_sizes) {
std::vector<TensorDescriptor> descriptors(batch_sizes.size()); std::vector<TensorDescriptor> descriptors(batch_sizes.size());
size_t i =0; size_t i = 0;
auto batch_tensor_size = tensor.sizes().vec(); auto batch_tensor_size = tensor.sizes().vec();
for (auto batch_size : batch_sizes) { for (auto batch_size : batch_sizes) {
@ -204,6 +243,8 @@ struct RNNParams {
struct RNNDescriptors { struct RNNDescriptors {
RNNDescriptor rnn_desc; RNNDescriptor rnn_desc;
static thread_local DropoutDescriptor dropout_desc;
static thread_local std::unique_ptr<DropoutState> dropout_states;
std::vector<TensorDescriptor> x_descs; std::vector<TensorDescriptor> x_descs;
std::vector<TensorDescriptor> y_descs; std::vector<TensorDescriptor> y_descs;
TensorDescriptor hx_desc; TensorDescriptor hx_desc;
@ -212,7 +253,39 @@ struct RNNDescriptors {
TensorDescriptor cy_desc; TensorDescriptor cy_desc;
RNNDescriptors(const RNNParams& fn, miopenHandle_t handle, Tensor x, Tensor y, Tensor hx, Tensor cx) { RNNDescriptors(const RNNParams& fn, miopenHandle_t handle, Tensor x, Tensor y, Tensor hx, Tensor cx) {
rnn_desc = fn.rnn.descriptor(); if (fn.rnn.dropout_rate == 0.0) {
rnn_desc = fn.rnn.descriptor();
} else {
if (!dropout_states) {
size_t states_size_in_bytes = 0;
MIOPEN_CHECK(miopenDropoutGetStatesSize(handle, &states_size_in_bytes));
size_t states_size = states_size_in_bytes / sizeof(rocrand_state_xorwow);
dropout_states = std::make_unique<DropoutState>(states_size * sizeof(rocrand_state_xorwow));
dropout_desc.set(handle,
fn.rnn.dropout_rate,
dropout_states->data,
dropout_states->size,
fn.rnn.dropout_seed,
false,
false,
miopenRNGType_t::MIOPEN_RNG_PSEUDO_XORWOW);
} else {
dropout_desc.restore(handle,
fn.rnn.dropout_rate,
dropout_states->data,
dropout_states->size,
fn.rnn.dropout_seed,
// use_mask flag must be true in order to continue from a saved RNG state
true,
false,
miopenRNGType_t::MIOPEN_RNG_PSEUDO_XORWOW);
}
rnn_desc = fn.rnn.descriptorWithDropout(dropout_desc);
}
x_descs = fn.tensors.descriptors(x); x_descs = fn.tensors.descriptors(x);
y_descs = fn.tensors.descriptors(y); y_descs = fn.tensors.descriptors(y);
hx_desc.set(hx, 5); hx_desc.set(hx, 5);
@ -239,6 +312,11 @@ struct RNNDescriptors {
} }
}; };
// We need to store both the dropout descriptor and state thread locally to avoid multithreading issues
// (MIOpen descriptor handles are not safe to share across threads).
thread_local DropoutDescriptor RNNDescriptors::dropout_desc {};
// Each state is 0.75 MB so there is no problem in caching all of them for each thread.
// Allocated lazily on first dropout use and reused for the thread's lifetime.
thread_local std::unique_ptr<DropoutState> RNNDescriptors::dropout_states { nullptr };
Tensor permute_wei_for_miopen(Tensor wei, int64_t mode) Tensor permute_wei_for_miopen(Tensor wei, int64_t mode)
{ {
if (mode < 2) if (mode < 2)
@ -492,7 +570,7 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor> miopen_rnn(
auto handle = getMiopenHandle(); auto handle = getMiopenHandle();
miopenRNNAlgo_t algo = miopenRNNdefault; miopenRNNAlgo_t algo = miopenRNNdefault;
fn.rnn.set_algo(algo); fn.rnn.set_algo(algo);
fn.rnn.set_dropout(fn_dropout);
RNNDescriptors descs(fn, handle, x, y, hx, cx); RNNDescriptors descs(fn, handle, x, y, hx, cx);
FilterDescriptor w_desc; FilterDescriptor w_desc;
@ -551,7 +629,6 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor> miopen_rnn(
} }
return std::make_tuple(output, hy, cy, reserve, weight_buf); return std::make_tuple(output, hy, cy, reserve, weight_buf);
} }
std::tuple<Tensor, Tensor, Tensor, Tensor> miopen_rnn_backward_input( std::tuple<Tensor, Tensor, Tensor, Tensor> miopen_rnn_backward_input(
@ -626,6 +703,7 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> miopen_rnn_backward_input(
miopenRNNAlgo_t algo = miopenRNNdefault; miopenRNNAlgo_t algo = miopenRNNdefault;
fn.rnn.set_algo(algo); fn.rnn.set_algo(algo);
fn.rnn.set_dropout(fn_dropout);
RNNDescriptors descs(fn, handle, x, y, hx, cx); RNNDescriptors descs(fn, handle, x, y, hx, cx);
FilterDescriptor w_desc; FilterDescriptor w_desc;
@ -720,6 +798,7 @@ std::vector<Tensor> miopen_rnn_backward_weight(
miopenRNNAlgo_t algo = miopenRNNdefault; miopenRNNAlgo_t algo = miopenRNNdefault;
fn.rnn.set_algo(algo); fn.rnn.set_algo(algo);
fn.rnn.set_dropout(fn_dropout);
RNNDescriptors descs(fn, handle, x, y, hx, cx); RNNDescriptors descs(fn, handle, x, y, hx, cx);
FilterDescriptor w_desc; FilterDescriptor w_desc;
@ -909,6 +988,6 @@ REGISTER_CUDA_DISPATCH(lstm_miopen_stub, &lstm_miopen)
REGISTER_CUDA_DISPATCH(lstm_packed_miopen_stub, &lstm_packed_miopen) REGISTER_CUDA_DISPATCH(lstm_packed_miopen_stub, &lstm_packed_miopen)
} // anonymous namespace } // anonymous namespace
}} //namespace native. } // namespace at::native
#endif #endif