[SE] Add cudnnTransformTensor to StreamExecutor.

PiperOrigin-RevId: 158062553
This commit is contained in:
Jingyue Wu 2017-06-05 14:30:41 -07:00 committed by TensorFlower Gardener
parent 827874c307
commit 9e6899720a
5 changed files with 71 additions and 0 deletions

View File

@ -2401,6 +2401,31 @@ DeviceMemory<T> CudnnSupport::MaybeTransformLayout(
return (*transform_scratch)->device_memory(); return (*transform_scratch)->device_memory();
} }
// Transforms `input_data` (laid out per `input_desc`) into `output_data`
// (laid out per `output_desc`) via cudnnTransformTensor. Both tensors are
// float. Returns false, after logging, if cuDNN reports an error.
bool CudnnSupport::DoTransformTensor(Stream* stream,
                                     const dnn::BatchDescriptor& input_desc,
                                     const DeviceMemory<float>& input_data,
                                     const dnn::BatchDescriptor& output_desc,
                                     DeviceMemory<float>* output_data) {
  mutex_lock lock{dnn_handle_mutex_};
  // cudnnTransformTensor computes out = alpha * transform(in) + beta * out;
  // alpha=1, beta=0 makes it a pure layout-converting copy.
  float one = 1.0f;
  float zero = 0.0f;
  ScopedTensorDescriptor src_desc(parent_, input_desc, CUDNN_DATA_FLOAT);
  ScopedTensorDescriptor dst_desc(parent_, output_desc, CUDNN_DATA_FLOAT);
  cudnnStatus_t result = wrap::cudnnTransformTensor(
      parent_, ToHandle(dnn_handle_), &one, src_desc.handle(),
      input_data.opaque(), &zero, dst_desc.handle(), output_data->opaque());
  if (result == CUDNN_STATUS_SUCCESS) {
    return true;
  }
  LOG(ERROR) << "Could not transform a tensor from layout "
             << input_desc.ToShortString() << " to "
             << output_desc.ToShortString();
  return false;
}
template <class T> template <class T>
bool CudnnSupport::DoConvolveBackwardDataImpl( bool CudnnSupport::DoConvolveBackwardDataImpl(
Stream* stream, Stream* stream,

View File

@ -481,6 +481,11 @@ class CudnnSupport : public dnn::DnnSupport {
std::unique_ptr<TemporaryDeviceMemory<T>>* transform_scratch) std::unique_ptr<TemporaryDeviceMemory<T>>* transform_scratch)
EXCLUSIVE_LOCKS_REQUIRED(dnn_handle_mutex_); EXCLUSIVE_LOCKS_REQUIRED(dnn_handle_mutex_);
// cuDNN-backed override of DnnSupport::DoTransformTensor; see that method
// for the contract. Implemented with cudnnTransformTensor.
bool DoTransformTensor(Stream* stream, const dnn::BatchDescriptor& input_desc,
const DeviceMemory<float>& input_data,
const dnn::BatchDescriptor& output_desc,
DeviceMemory<float>* output_data) override;
template <class T> template <class T>
bool DoBatchNormalizationForwardImpl( bool DoBatchNormalizationForwardImpl(
Stream* stream, dnn::DataType data_type, const DeviceMemory<T>& x, Stream* stream, dnn::DataType data_type, const DeviceMemory<T>& x,

View File

@ -1960,6 +1960,23 @@ class DnnSupport {
return false; return false;
} }
// Transforms a tensor into another tensor with a different layout and/or data
// type.
//
// Arguments:
//  stream: pointer to the stream where this operation should be enqueued to.
//  input_desc: descriptor for the input tensor.
//  input_data: the device memory region that contains the input tensor.
//  output_desc: descriptor for the output tensor.
//  output_data: the device memory region that contains the output tensor.
//
// Returns true on success. This base implementation is a stub that always
// returns false; DNN backends that support tensor transforms override it.
virtual bool DoTransformTensor(Stream* stream,
const dnn::BatchDescriptor& input_desc,
const DeviceMemory<float>& input_data,
const dnn::BatchDescriptor& output_desc,
DeviceMemory<float>* output_data) {
return false;
}
private: private:
SE_DISALLOW_COPY_AND_ASSIGN(DnnSupport); SE_DISALLOW_COPY_AND_ASSIGN(DnnSupport);
}; };

View File

@ -4389,6 +4389,23 @@ Stream &Stream::ThenRnnBackward(
return *this; return *this;
} }
// Enqueues a tensor-layout transform (DnnSupport::DoTransformTensor) onto
// this stream. If the stream is already in an error state the call is a
// no-op; if the platform has no DNN support an error is recorded instead.
// Returns *this to allow call chaining.
Stream &Stream::ThenTransformTensor(const dnn::BatchDescriptor &input_desc,
                                    const DeviceMemory<float> &input_data,
                                    const dnn::BatchDescriptor &output_desc,
                                    DeviceMemory<float> *output_data) {
  VLOG_CALL(PARAM(input_desc), PARAM(input_data), PARAM(output_desc),
            PARAM(output_data));
  if (!ok()) {
    return *this;
  }
  dnn::DnnSupport *dnn = parent_->AsDnn();
  if (dnn == nullptr) {
    SetErrorAndLogNoDnnSupport();
    return *this;
  }
  CheckError(dnn->DoTransformTensor(this, input_desc, input_data, output_desc,
                                    output_data));
  return *this;
}
Stream &Stream::ThenDoHostCallbackForTest(std::function<void()> callback) { Stream &Stream::ThenDoHostCallbackForTest(std::function<void()> callback) {
VLOG_CALL(PARAM(callback)); VLOG_CALL(PARAM(callback));

View File

@ -1653,6 +1653,13 @@ class Stream {
DeviceMemory<uint8> *reserve_space_data, DeviceMemory<uint8> *reserve_space_data,
ScratchAllocator *workspace_allocator); ScratchAllocator *workspace_allocator);
// Enqueue onto the stream an operation that transforms a tensor into another
// tensor with a different layout and/or data type.
// See DnnSupport::DoTransformTensor for more details.
Stream &ThenTransformTensor(const dnn::BatchDescriptor &input_desc,
const DeviceMemory<float> &input_data,
const dnn::BatchDescriptor &output_desc,
DeviceMemory<float> *output_data);
// (Synchronously) block the host code waiting for the operations // (Synchronously) block the host code waiting for the operations
// entrained on the stream (enqueued to this point in program // entrained on the stream (enqueued to this point in program
// execution) to complete. // execution) to complete.