[SE] Add cudnnTransformTensor to StreamExecutor.
PiperOrigin-RevId: 158062553
This commit is contained in:
parent 827874c307
commit 9e6899720a
@@ -2401,6 +2401,31 @@ DeviceMemory<T> CudnnSupport::MaybeTransformLayout(
  return (*transform_scratch)->device_memory();
}

bool CudnnSupport::DoTransformTensor(Stream* stream,
                                     const dnn::BatchDescriptor& input_desc,
                                     const DeviceMemory<float>& input_data,
                                     const dnn::BatchDescriptor& output_desc,
                                     DeviceMemory<float>* output_data) {
  mutex_lock lock{dnn_handle_mutex_};
  float alpha = 1.0f;
  float beta = 0.0f;
  ScopedTensorDescriptor input_tensor_desc(parent_, input_desc,
                                           CUDNN_DATA_FLOAT);
  ScopedTensorDescriptor output_tensor_desc(parent_, output_desc,
                                            CUDNN_DATA_FLOAT);
  cudnnStatus_t status = wrap::cudnnTransformTensor(
      parent_, ToHandle(dnn_handle_), &alpha, input_tensor_desc.handle(),
      input_data.opaque(), &beta, output_tensor_desc.handle(),
      output_data->opaque());
  if (status != CUDNN_STATUS_SUCCESS) {
    LOG(ERROR) << "Could not transform a tensor from layout "
               << input_desc.ToShortString() << " to "
               << output_desc.ToShortString();
    return false;
  }
  return true;
}

template <class T>
bool CudnnSupport::DoConvolveBackwardDataImpl(
    Stream* stream,
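For context, the wrapper above bottoms out in cuDNN's cudnnTransformTensor, which computes y = alpha * transform(x) + beta * y over two tensor descriptors that describe the same logical shape in different layouts; with alpha = 1 and beta = 0, as used here, it is a pure relayout. The standalone sketch below is not part of this commit: it shows the raw cuDNN call directly, without StreamExecutor, for an NHWC-to-NCHW transform, with error handling reduced to a single status check.

// Minimal raw-cuDNN sketch (not from this commit): transform a float tensor
// from NHWC to NCHW, which is the call the StreamExecutor wrapper delegates to.
#include <cudnn.h>
#include <cuda_runtime.h>
#include <cstdio>

int main() {
  const int n = 1, c = 3, h = 4, w = 4;
  const size_t bytes = size_t(n) * c * h * w * sizeof(float);

  cudnnHandle_t handle;
  cudnnCreate(&handle);

  // Same logical shape, two layouts: x is NHWC, y is NCHW.
  cudnnTensorDescriptor_t x_desc, y_desc;
  cudnnCreateTensorDescriptor(&x_desc);
  cudnnCreateTensorDescriptor(&y_desc);
  cudnnSetTensor4dDescriptor(x_desc, CUDNN_TENSOR_NHWC, CUDNN_DATA_FLOAT, n, c, h, w);
  cudnnSetTensor4dDescriptor(y_desc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, n, c, h, w);

  float *x = nullptr, *y = nullptr;
  cudaMalloc(&x, bytes);
  cudaMalloc(&y, bytes);

  // y = alpha * transform(x) + beta * y; alpha = 1, beta = 0 is a plain
  // relayout, matching the constants in CudnnSupport::DoTransformTensor above.
  float alpha = 1.0f, beta = 0.0f;
  cudnnStatus_t status =
      cudnnTransformTensor(handle, &alpha, x_desc, x, &beta, y_desc, y);
  if (status != CUDNN_STATUS_SUCCESS) {
    std::fprintf(stderr, "cudnnTransformTensor failed: %s\n",
                 cudnnGetErrorString(status));
  }

  cudaFree(x);
  cudaFree(y);
  cudnnDestroyTensorDescriptor(x_desc);
  cudnnDestroyTensorDescriptor(y_desc);
  cudnnDestroy(handle);
  return status == CUDNN_STATUS_SUCCESS ? 0 : 1;
}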
@@ -481,6 +481,11 @@ class CudnnSupport : public dnn::DnnSupport {
      std::unique_ptr<TemporaryDeviceMemory<T>>* transform_scratch)
      EXCLUSIVE_LOCKS_REQUIRED(dnn_handle_mutex_);

  bool DoTransformTensor(Stream* stream, const dnn::BatchDescriptor& input_desc,
                         const DeviceMemory<float>& input_data,
                         const dnn::BatchDescriptor& output_desc,
                         DeviceMemory<float>* output_data) override;

  template <class T>
  bool DoBatchNormalizationForwardImpl(
      Stream* stream, dnn::DataType data_type, const DeviceMemory<T>& x,
@@ -1960,6 +1960,23 @@ class DnnSupport {
    return false;
  }

  // Transforms a tensor into another tensor with a different layout and/or data
  // type.
  //
  // Arguments:
  //  stream: pointer to the stream where this operation should be enqueued to.
  //  input_desc: descriptor for the input tensor.
  //  input_data: the device memory region that contains the input tensor.
  //  output_desc: descriptor for the output tensor.
  //  output_data: the device memory region that contains the output tensor.
  virtual bool DoTransformTensor(Stream* stream,
                                 const dnn::BatchDescriptor& input_desc,
                                 const DeviceMemory<float>& input_data,
                                 const dnn::BatchDescriptor& output_desc,
                                 DeviceMemory<float>* output_data) {
    return false;
  }

 private:
  SE_DISALLOW_COPY_AND_ASSIGN(DnnSupport);
};
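The input_desc/output_desc arguments documented above are ordinary dnn::BatchDescriptor values; a layout change is expressed simply by giving the two descriptors identical logical dimensions but different layouts. The helper below is an illustration only, not code from this commit: the function name and include path are assumptions, the dimension values are arbitrary, and it relies on the existing BatchDescriptor chained setters and the kBatchYXDepth (NHWC) / kBatchDepthYX (NCHW) layout enumerators.

// Illustration only (not from this commit): build two BatchDescriptors that
// describe the same 1x3x4x4 tensor in NHWC and NCHW. Include path assumed;
// the helper name is hypothetical.
#include <utility>

#include "tensorflow/stream_executor/dnn.h"

namespace gpu = perftools::gputools;

std::pair<gpu::dnn::BatchDescriptor, gpu::dnn::BatchDescriptor>
MakeNhwcToNchwDescriptors() {
  gpu::dnn::BatchDescriptor input_desc;
  input_desc.set_count(1)
      .set_feature_map_count(3)
      .set_height(4)
      .set_width(4)
      .set_layout(gpu::dnn::DataLayout::kBatchYXDepth);  // NHWC

  gpu::dnn::BatchDescriptor output_desc;
  output_desc.set_count(1)
      .set_feature_map_count(3)
      .set_height(4)
      .set_width(4)
      .set_layout(gpu::dnn::DataLayout::kBatchDepthYX);  // NCHW

  return {input_desc, output_desc};
}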
@@ -4389,6 +4389,23 @@ Stream &Stream::ThenRnnBackward(
  return *this;
}

Stream &Stream::ThenTransformTensor(const dnn::BatchDescriptor &input_desc,
                                    const DeviceMemory<float> &input_data,
                                    const dnn::BatchDescriptor &output_desc,
                                    DeviceMemory<float> *output_data) {
  VLOG_CALL(PARAM(input_desc), PARAM(input_data), PARAM(output_desc),
            PARAM(output_data));
  if (ok()) {
    if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
      CheckError(dnn->DoTransformTensor(this, input_desc, input_data,
                                        output_desc, output_data));
    } else {
      SetErrorAndLogNoDnnSupport();
    }
  }
  return *this;
}

Stream &Stream::ThenDoHostCallbackForTest(std::function<void()> callback) {
  VLOG_CALL(PARAM(callback));
@@ -1653,6 +1653,13 @@ class Stream {
      DeviceMemory<uint8> *reserve_space_data,
      ScratchAllocator *workspace_allocator);

  // Enqueue onto the stream an operation that transforms a tensor.
  // See DnnSupport::DoTransformTensor for more details.
  Stream &ThenTransformTensor(const dnn::BatchDescriptor &input_desc,
                              const DeviceMemory<float> &input_data,
                              const dnn::BatchDescriptor &output_desc,
                              DeviceMemory<float> *output_data);

  // (Synchronously) block the host code waiting for the operations
  // entrained on the stream (enqueued to this point in program
  // execution) to complete.
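Putting the pieces together, a caller-side sketch of the new entry point might look like the following. This is a hedged example rather than code from the commit: the function name is hypothetical, the include path is assumed, and the stream plus the two DeviceMemory<float> allocations are presumed to have been set up elsewhere, as in any StreamExecutor client.

// Hedged usage sketch (not from this commit). Assumes `stream` is an
// initialized Stream on a CUDA StreamExecutor and that the input and output
// DeviceMemory<float> buffers hold matching element counts.
#include "tensorflow/stream_executor/stream.h"  // include path assumed

void RelayoutOnStream(
    perftools::gputools::Stream* stream,
    const perftools::gputools::dnn::BatchDescriptor& input_desc,
    const perftools::gputools::DeviceMemory<float>& input_mem,
    const perftools::gputools::dnn::BatchDescriptor& output_desc,
    perftools::gputools::DeviceMemory<float>* output_mem) {
  // Enqueue the transform; like other Then* calls this is asynchronous and
  // marks the stream as errored if the DNN backend rejects the request.
  stream->ThenTransformTensor(input_desc, input_mem, output_desc, output_mem);

  // Optional: synchronize before reading the result back on the host.
  stream->BlockHostUntilDone();
}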