[ROCm] Implemented dropout usage for RNN with MIOpen backend (#144572)
This PR fixes https://github.com/pytorch/pytorch/issues/107183 for ROCm: it implements the use of the new RNN descriptor in the MIOpen backend, which takes the dropout rate into account through a dropout descriptor. This fixes the associated test_RNN_dropout_state test.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144572
Approved by: https://github.com/jeffdaily
Co-authored-by: Jeff Daily <jeff.daily@amd.com>
Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
commit 1aa971a3bb (parent 2c5c793085)
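For context, a minimal sketch (illustrative only, not taken from this commit) of the PyTorch-level situation the change affects: a multi-layer RNN with a non-zero dropout rate running on a ROCm device, where the dropout value is now forwarded to the MIOpen RNN descriptor rather than being left out of the descriptor setup.

# Illustrative sketch only: multi-layer LSTM with inter-layer dropout on a ROCm build,
# where the MIOpen backend handles the RNN.
import torch
import torch.nn as nn

device = torch.device("cuda")  # on ROCm builds, the HIP device is exposed as "cuda"

# num_layers > 1 with dropout > 0 is the case where RNN dropout actually applies
rnn = nn.LSTM(input_size=16, hidden_size=32, num_layers=2, dropout=0.5).to(device)
rnn.train()  # dropout between layers is only active in training mode

x = torch.randn(10, 4, 16, device=device)  # (seq_len, batch, input_size)
output, (hn, cn) = rnn(x)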
@@ -121,6 +121,21 @@ struct ConvolutionDescriptor
   }
 };
 
+struct DropoutDescriptor
+  : public Descriptor<miopenDropoutDescriptor,
+                      &miopenCreateDropoutDescriptor,
+                      &miopenDestroyDropoutDescriptor>
+{
+  void set(miopenHandle_t handle, float dropout, void* states, size_t stateSizeInBytes,
+           unsigned long long seed, bool use_mask, bool state_evo, miopenRNGType_t rng_mode) {
+    MIOPEN_CHECK(miopenSetDropoutDescriptor(mut_desc(), handle, dropout, states, stateSizeInBytes, seed, use_mask, state_evo, rng_mode));
+  }
+
+  void restore(miopenHandle_t handle, float dropout, void* states, size_t stateSizeInBytes,
+               unsigned long long seed, bool use_mask, bool state_evo, miopenRNGType_t rng_mode) {
+    MIOPEN_CHECK(miopenRestoreDropoutDescriptor(mut_desc(), handle, dropout, states, stateSizeInBytes, seed, use_mask, state_evo, rng_mode));
+  }
+};
+
 struct RNNDescriptor
   : public Descriptor<miopenRNNDescriptor,

@@ -128,9 +143,14 @@ struct RNNDescriptor
                       &miopenDestroyRNNDescriptor>
 {
   void set(int64_t hidden_size, int64_t num_layers, miopenRNNInputMode_t input_mode, miopenRNNDirectionMode_t direction, miopenRNNMode_t rnn_mode,
            miopenRNNBiasMode_t bias_mode, miopenRNNAlgo_t algorithm, miopenDataType_t datatype) {
     MIOPEN_CHECK(miopenSetRNNDescriptor(mut_desc(), hidden_size, num_layers, input_mode, direction, rnn_mode, bias_mode, algorithm, datatype));
   }
+
+  void setWithDropout(DropoutDescriptor& dropout_desc, int64_t hidden_size, int64_t num_layers, miopenRNNInputMode_t input_mode, miopenRNNDirectionMode_t direction,
+                      miopenRNNMode_t rnn_mode, miopenRNNBiasMode_t bias_mode, miopenRNNAlgo_t algorithm, miopenDataType_t datatype) {
+    MIOPEN_CHECK(miopenSetRNNDescriptor_V2(mut_desc(), hidden_size, num_layers, dropout_desc.mut_desc(), input_mode, direction, rnn_mode, bias_mode, algorithm, datatype));
+  }
 };
 
 union Constant

@@ -57,6 +57,10 @@ namespace at::native {
 
 #include <ATen/TensorUtils.h>
 
+#include <c10/hip/HIPCachingAllocator.h>
+
+#include <rocrand/rocrand_xorwow.h>
+
 #include <functional>
 #include <iterator>
 #include <sstream>

@@ -66,12 +70,35 @@ namespace at::native {
 #include <stdint.h>
 #include <unordered_map>
 
-namespace at { namespace native {
+namespace at::native {
+
+namespace {
+
+struct DropoutState {
+  DropoutState(size_t size) : size(size), data(NULL) {
+    data = c10::hip::HIPCachingAllocator::raw_alloc(size);
+  }
+  DropoutState(const DropoutState&) = delete;
+  DropoutState(DropoutState&&) = default;
+  DropoutState& operator=(DropoutState&&) = default;
+  ~DropoutState() {
+    if (data) {
+      c10::hip::HIPCachingAllocator::raw_delete(data);
+    }
+  }
+
+  size_t size;
+  void* data;
+};
+
+} // anonymous
+
 //RNNDescriptor.
 struct RNNDescriptorParams {
   int64_t hidden_size;
   int64_t num_layers;
+  double dropout_rate;
+  uint64_t dropout_seed;
   miopenRNNDirectionMode_t direction;
   miopenRNNMode_t rnn_mode;
   miopenDataType_t datatype;

@@ -114,6 +141,12 @@ struct RNNDescriptorParams {
     }
   }
 
+  void set_dropout(double dropout_rate, uint64_t dropout_seed = 0) {
+    this->dropout_rate = dropout_rate;
+    // TODO: Implement seed setting for RNN dropout
+    this->dropout_seed = dropout_seed;
+  }
+
   void set(int64_t mode, int64_t hidden_size, int64_t num_layers, bool bidirectional, miopenDataType_t datatype, miopenRNNBiasMode_t bias_mode) {
     this->set_mode(mode);
     this->hidden_size = hidden_size;

@@ -128,12 +161,18 @@ struct RNNDescriptorParams {
     rnn_desc.set(hidden_size, num_layers, input_mode, direction, rnn_mode, bias_mode, algo, datatype);
     return rnn_desc;
   }
 
+  RNNDescriptor descriptorWithDropout(DropoutDescriptor& dropout_desc) const {
+    RNNDescriptor rnn_desc;
+    rnn_desc.setWithDropout(dropout_desc, hidden_size, num_layers, input_mode, direction, rnn_mode, bias_mode, algo, datatype);
+    return rnn_desc;
+  }
 };
 
 //TensorDescriptor list.
 std::vector<TensorDescriptor> rnn_descriptor_sequence(const Tensor& tensor, IntArrayRef batch_sizes) {
   std::vector<TensorDescriptor> descriptors(batch_sizes.size());
-  size_t i =0;
+  size_t i = 0;
 
   auto batch_tensor_size = tensor.sizes().vec();
   for (auto batch_size : batch_sizes) {

@@ -204,6 +243,8 @@ struct RNNParams {
 
 struct RNNDescriptors {
   RNNDescriptor rnn_desc;
+  static thread_local DropoutDescriptor dropout_desc;
+  static thread_local std::unique_ptr<DropoutState> dropout_states;
   std::vector<TensorDescriptor> x_descs;
   std::vector<TensorDescriptor> y_descs;
   TensorDescriptor hx_desc;

@@ -212,7 +253,39 @@ struct RNNDescriptors {
   TensorDescriptor cy_desc;
 
   RNNDescriptors(const RNNParams& fn, miopenHandle_t handle, Tensor x, Tensor y, Tensor hx, Tensor cx) {
-    rnn_desc = fn.rnn.descriptor();
+    if (fn.rnn.dropout_rate == 0.0) {
+      rnn_desc = fn.rnn.descriptor();
+    } else {
+      if (!dropout_states) {
+        size_t states_size_in_bytes = 0;
+        MIOPEN_CHECK(miopenDropoutGetStatesSize(handle, &states_size_in_bytes));
+        size_t states_size = states_size_in_bytes / sizeof(rocrand_state_xorwow);
+
+        dropout_states = std::make_unique<DropoutState>(states_size * sizeof(rocrand_state_xorwow));
+
+        dropout_desc.set(handle,
+                         fn.rnn.dropout_rate,
+                         dropout_states->data,
+                         dropout_states->size,
+                         fn.rnn.dropout_seed,
+                         false,
+                         false,
+                         miopenRNGType_t::MIOPEN_RNG_PSEUDO_XORWOW);
+      } else {
+        dropout_desc.restore(handle,
+                             fn.rnn.dropout_rate,
+                             dropout_states->data,
+                             dropout_states->size,
+                             fn.rnn.dropout_seed,
+                             // use_mask flag must be true in order to continue from a saved RNG state
+                             true,
+                             false,
+                             miopenRNGType_t::MIOPEN_RNG_PSEUDO_XORWOW);
+      }
+
+      rnn_desc = fn.rnn.descriptorWithDropout(dropout_desc);
+    }
+
     x_descs = fn.tensors.descriptors(x);
     y_descs = fn.tensors.descriptors(y);
     hx_desc.set(hx, 5);

@@ -239,6 +312,11 @@ struct RNNDescriptors {
   }
 };
 
+// We need to store both the dropout descriptor and state thread locally to avoid multithreading issues
+thread_local DropoutDescriptor RNNDescriptors::dropout_desc {};
+// Each state is 0.75 MB so there is no problem in caching all of them for each thread
+thread_local std::unique_ptr<DropoutState> RNNDescriptors::dropout_states { nullptr };
+
 Tensor permute_wei_for_miopen(Tensor wei, int64_t mode)
 {
   if (mode < 2)

@@ -492,7 +570,7 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor> miopen_rnn(
   auto handle = getMiopenHandle();
   miopenRNNAlgo_t algo = miopenRNNdefault;
   fn.rnn.set_algo(algo);
+  fn.rnn.set_dropout(fn_dropout);
   RNNDescriptors descs(fn, handle, x, y, hx, cx);
 
   FilterDescriptor w_desc;

@@ -551,7 +629,6 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor> miopen_rnn(
   }
 
   return std::make_tuple(output, hy, cy, reserve, weight_buf);
 
 }
 
 std::tuple<Tensor, Tensor, Tensor, Tensor> miopen_rnn_backward_input(

@@ -626,6 +703,7 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> miopen_rnn_backward_input(
 
   miopenRNNAlgo_t algo = miopenRNNdefault;
   fn.rnn.set_algo(algo);
+  fn.rnn.set_dropout(fn_dropout);
   RNNDescriptors descs(fn, handle, x, y, hx, cx);
 
   FilterDescriptor w_desc;

@@ -720,6 +798,7 @@ std::vector<Tensor> miopen_rnn_backward_weight(
 
   miopenRNNAlgo_t algo = miopenRNNdefault;
   fn.rnn.set_algo(algo);
+  fn.rnn.set_dropout(fn_dropout);
   RNNDescriptors descs(fn, handle, x, y, hx, cx);
 
   FilterDescriptor w_desc;

@@ -909,6 +988,6 @@ REGISTER_CUDA_DISPATCH(lstm_miopen_stub, &lstm_miopen)
 REGISTER_CUDA_DISPATCH(lstm_packed_miopen_stub, &lstm_packed_miopen)
 
 } // anonymous namespace
-}} //namespace native.
+} // namespace at::native
 
 #endif