Mirror of https://github.com/zebrajr/tensorflow.git, last synced 2025-12-07 12:20:24 +01:00.
[SE] Add cudnnTransformTensor to StreamExecutor.
PiperOrigin-RevId: 158062553
This commit is contained in:
parent
827874c307
commit
9e6899720a
|
|
@ -2401,6 +2401,31 @@ DeviceMemory<T> CudnnSupport::MaybeTransformLayout(
|
||||||
return (*transform_scratch)->device_memory();
|
return (*transform_scratch)->device_memory();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Transforms `input_data` (laid out per `input_desc`) into `output_data`
// (laid out per `output_desc`) using cudnnTransformTensor. Both tensors are
// float32. Returns true on success; on cudnn failure logs an error and
// returns false.
bool CudnnSupport::DoTransformTensor(Stream* stream,
                                     const dnn::BatchDescriptor& input_desc,
                                     const DeviceMemory<float>& input_data,
                                     const dnn::BatchDescriptor& output_desc,
                                     DeviceMemory<float>* output_data) {
  mutex_lock lock{dnn_handle_mutex_};
  // Bind the cudnn handle to the stream this op is enqueued on. The original
  // code ignored `stream` entirely, so the transform could run on whatever
  // stream the handle was last associated with, racing with work enqueued on
  // `stream`. This mirrors the cudnnSetStream call the other Do* ops make
  // under dnn_handle_mutex_.
  cudnnStatus_t status = wrap::cudnnSetStream(parent_, ToHandle(dnn_handle_),
                                              AsCUDAStreamValue(stream));
  if (status != CUDNN_STATUS_SUCCESS) {
    LOG(ERROR) << "failed to set stream for cudnn handle: " << ToString(status);
    return false;
  }

  // cudnn computes: output = alpha * transform(input) + beta * output,
  // so alpha=1, beta=0 is a pure layout transform.
  float alpha = 1.0f;
  float beta = 0.0f;
  ScopedTensorDescriptor input_tensor_desc(parent_, input_desc,
                                           CUDNN_DATA_FLOAT);
  ScopedTensorDescriptor output_tensor_desc(parent_, output_desc,
                                            CUDNN_DATA_FLOAT);
  status = wrap::cudnnTransformTensor(
      parent_, ToHandle(dnn_handle_), &alpha, input_tensor_desc.handle(),
      input_data.opaque(), &beta, output_tensor_desc.handle(),
      output_data->opaque());
  if (status != CUDNN_STATUS_SUCCESS) {
    LOG(ERROR) << "Could not transform a tensor from layout "
               << input_desc.ToShortString() << " to "
               << output_desc.ToShortString();
    return false;
  }
  return true;
}
|
||||||
|
|
||||||
template <class T>
|
template <class T>
|
||||||
bool CudnnSupport::DoConvolveBackwardDataImpl(
|
bool CudnnSupport::DoConvolveBackwardDataImpl(
|
||||||
Stream* stream,
|
Stream* stream,
|
||||||
|
|
|
||||||
|
|
@ -481,6 +481,11 @@ class CudnnSupport : public dnn::DnnSupport {
|
||||||
std::unique_ptr<TemporaryDeviceMemory<T>>* transform_scratch)
|
std::unique_ptr<TemporaryDeviceMemory<T>>* transform_scratch)
|
||||||
EXCLUSIVE_LOCKS_REQUIRED(dnn_handle_mutex_);
|
EXCLUSIVE_LOCKS_REQUIRED(dnn_handle_mutex_);
|
||||||
|
|
||||||
|
// Transforms `input_data` (laid out per `input_desc`) into `output_data`
// (laid out per `output_desc`). Overrides dnn::DnnSupport::DoTransformTensor;
// see that declaration for the full contract.
bool DoTransformTensor(Stream* stream, const dnn::BatchDescriptor& input_desc,
                       const DeviceMemory<float>& input_data,
                       const dnn::BatchDescriptor& output_desc,
                       DeviceMemory<float>* output_data) override;
|
||||||
|
|
||||||
template <class T>
|
template <class T>
|
||||||
bool DoBatchNormalizationForwardImpl(
|
bool DoBatchNormalizationForwardImpl(
|
||||||
Stream* stream, dnn::DataType data_type, const DeviceMemory<T>& x,
|
Stream* stream, dnn::DataType data_type, const DeviceMemory<T>& x,
|
||||||
|
|
|
||||||
|
|
@ -1960,6 +1960,23 @@ class DnnSupport {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Transforms a tensor into another tensor with a different layout and/or data
// type.
//
// Arguments:
//  stream: pointer to the stream where this operation should be enqueued to.
//  input_desc: descriptor for the input tensor.
//  input_data: the device memory region that contains the input tensor.
//  output_desc: descriptor for the output tensor.
//  output_data: the device memory region that contains the output tensor.
//
// Returns true on success. This base-class default is a stub that always
// returns false, signalling "not supported"; DNN backends override it.
virtual bool DoTransformTensor(Stream* stream,
                               const dnn::BatchDescriptor& input_desc,
                               const DeviceMemory<float>& input_data,
                               const dnn::BatchDescriptor& output_desc,
                               DeviceMemory<float>* output_data) {
  return false;
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
SE_DISALLOW_COPY_AND_ASSIGN(DnnSupport);
|
SE_DISALLOW_COPY_AND_ASSIGN(DnnSupport);
|
||||||
};
|
};
|
||||||
|
|
|
||||||
|
|
@ -4389,6 +4389,23 @@ Stream &Stream::ThenRnnBackward(
|
||||||
return *this;
|
return *this;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Enqueues a tensor-layout transform on this stream by delegating to the
// platform's DNN support; see DnnSupport::DoTransformTensor. Returns *this
// so calls can be chained.
Stream &Stream::ThenTransformTensor(const dnn::BatchDescriptor &input_desc,
                                    const DeviceMemory<float> &input_data,
                                    const dnn::BatchDescriptor &output_desc,
                                    DeviceMemory<float> *output_data) {
  VLOG_CALL(PARAM(input_desc), PARAM(input_data), PARAM(output_desc),
            PARAM(output_data));
  // Guard clause: a stream already in an error state enqueues nothing.
  if (!ok()) {
    return *this;
  }
  dnn::DnnSupport *dnn_support = parent_->AsDnn();
  if (dnn_support == nullptr) {
    // No DNN plugin for this platform: record the error on the stream.
    SetErrorAndLogNoDnnSupport();
  } else {
    CheckError(dnn_support->DoTransformTensor(this, input_desc, input_data,
                                              output_desc, output_data));
  }
  return *this;
}
|
||||||
|
|
||||||
Stream &Stream::ThenDoHostCallbackForTest(std::function<void()> callback) {
|
Stream &Stream::ThenDoHostCallbackForTest(std::function<void()> callback) {
|
||||||
VLOG_CALL(PARAM(callback));
|
VLOG_CALL(PARAM(callback));
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1653,6 +1653,13 @@ class Stream {
|
||||||
DeviceMemory<uint8> *reserve_space_data,
|
DeviceMemory<uint8> *reserve_space_data,
|
||||||
ScratchAllocator *workspace_allocator);
|
ScratchAllocator *workspace_allocator);
|
||||||
|
|
||||||
|
// Enqueue onto the stream an operation that transforms a tensor.
// See DnnSupport::DoTransformTensor for more details.
// Returns *this so calls can be chained.
Stream &ThenTransformTensor(const dnn::BatchDescriptor &input_desc,
                            const DeviceMemory<float> &input_data,
                            const dnn::BatchDescriptor &output_desc,
                            DeviceMemory<float> *output_data);
|
||||||
|
|
||||||
// (Synchronously) block the host code waiting for the operations
|
// (Synchronously) block the host code waiting for the operations
|
||||||
// entrained on the stream (enqueued to this point in program
|
// entrained on the stream (enqueued to this point in program
|
||||||
// execution) to complete.
|
// execution) to complete.
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user