Set winograd nofused flag to be true by default.

Disable winograd nonfused conv for certain input params to avoid a known bug in cuDNNv5 and cuDNNv6.

PiperOrigin-RevId: 157352847
This commit is contained in:
Yangzihao Wang 2017-05-28 14:20:19 -07:00 committed by TensorFlower Gardener
parent 3f9b69a50f
commit e78e5ec8a8
16 changed files with 135 additions and 34 deletions

View File

@ -258,15 +258,21 @@ tensorflow::Status ConvolutionThunk::Convolve(
std::vector<se::dnn::AlgorithmType> ConvolutionThunk::GetAlgorithms( std::vector<se::dnn::AlgorithmType> ConvolutionThunk::GetAlgorithms(
se::StreamExecutor* stream_exec) const { se::StreamExecutor* stream_exec) const {
std::vector<se::dnn::AlgorithmType> algorithms; std::vector<se::dnn::AlgorithmType> algorithms;
// TODO(yangzihao): Currently disable the use of winograd nonfused in XLA
// by default. Should send in conv parameters and enable it when
// ShouldIncludeWinogradNonfusedAlgo() returns true.
switch (convolution_kind_) { switch (convolution_kind_) {
case ConvolutionKind::kBackwardFilter: case ConvolutionKind::kBackwardFilter:
CHECK(stream_exec->GetConvolveBackwardFilterAlgorithms(&algorithms)); CHECK(stream_exec->GetConvolveBackwardFilterAlgorithms(
/*with_winograd_nonfused=*/false, &algorithms));
break; break;
case ConvolutionKind::kBackwardInput: case ConvolutionKind::kBackwardInput:
CHECK(stream_exec->GetConvolveBackwardDataAlgorithms(&algorithms)); CHECK(stream_exec->GetConvolveBackwardDataAlgorithms(
/*with_winograd_nonfused=*/false, &algorithms));
break; break;
case ConvolutionKind::kForward: case ConvolutionKind::kForward:
CHECK(stream_exec->GetConvolveAlgorithms(&algorithms)); CHECK(stream_exec->GetConvolveAlgorithms(/*with_winograd_nonfused=*/false,
&algorithms));
break; break;
} }
return algorithms; return algorithms;

View File

@ -776,7 +776,8 @@ class Conv2DSlowBackpropFilterOp : public OpKernel {
if (cudnn_use_autotune_ && !AutoTuneConvBwdFilter::GetInstance()->Find( if (cudnn_use_autotune_ && !AutoTuneConvBwdFilter::GetInstance()->Find(
conv_parameters, &algorithm_config)) { conv_parameters, &algorithm_config)) {
std::vector<AlgorithmType> algorithms; std::vector<AlgorithmType> algorithms;
CHECK(stream->parent()->GetConvolveBackwardFilterAlgorithms(&algorithms)); CHECK(stream->parent()->GetConvolveBackwardFilterAlgorithms(
conv_parameters.ShouldIncludeWinogradNonfusedAlgo<T>(), &algorithms));
ProfileResult best_result; ProfileResult best_result;
ProfileResult best_result_no_scratch; ProfileResult best_result_no_scratch;
for (auto profile_algorithm : algorithms) { for (auto profile_algorithm : algorithms) {

View File

@ -856,7 +856,8 @@ class Conv2DSlowBackpropInputOp : public OpKernel {
if (cudnn_use_autotune_ && !AutoTuneConvBwdData::GetInstance()->Find( if (cudnn_use_autotune_ && !AutoTuneConvBwdData::GetInstance()->Find(
conv_parameters, &algorithm_config)) { conv_parameters, &algorithm_config)) {
std::vector<AlgorithmType> algorithms; std::vector<AlgorithmType> algorithms;
CHECK(stream->parent()->GetConvolveBackwardDataAlgorithms(&algorithms)); CHECK(stream->parent()->GetConvolveBackwardDataAlgorithms(
conv_parameters.ShouldIncludeWinogradNonfusedAlgo<T>(), &algorithms));
ProfileResult best_result; ProfileResult best_result;
ProfileResult best_result_no_scratch; ProfileResult best_result_no_scratch;
for (auto profile_algorithm : algorithms) { for (auto profile_algorithm : algorithms) {

View File

@ -656,7 +656,8 @@ class Conv3DBackpropInputOp<GPUDevice, T> : public OpKernel {
if (cudnn_use_autotune_ && !AutoTuneConv3dBwdData::GetInstance()->Find( if (cudnn_use_autotune_ && !AutoTuneConv3dBwdData::GetInstance()->Find(
conv_parameters, &algorithm_config)) { conv_parameters, &algorithm_config)) {
std::vector<AlgorithmType> algorithms; std::vector<AlgorithmType> algorithms;
CHECK(stream->parent()->GetConvolveBackwardDataAlgorithms(&algorithms)); CHECK(stream->parent()->GetConvolveBackwardDataAlgorithms(
conv_parameters.ShouldIncludeWinogradNonfusedAlgo<T>(), &algorithms));
ProfileResult best_result; ProfileResult best_result;
ProfileResult best_result_no_scratch; ProfileResult best_result_no_scratch;
for (auto profile_algorithm : algorithms) { for (auto profile_algorithm : algorithms) {
@ -1020,11 +1021,11 @@ class Conv3DBackpropFilterOp<GPUDevice, T> : public OpKernel {
using perftools::gputools::dnn::ProfileResult; using perftools::gputools::dnn::ProfileResult;
using perftools::gputools::dnn::kDefaultAlgorithm; using perftools::gputools::dnn::kDefaultAlgorithm;
AlgorithmConfig algorithm_config; AlgorithmConfig algorithm_config;
if (cudnn_use_autotune_ && !AutoTuneConv3dBwdFilter::GetInstance()->Find( if (cudnn_use_autotune_ && !AutoTuneConv3dBwdFilter::GetInstance()->Find(
conv_parameters, &algorithm_config)) { conv_parameters, &algorithm_config)) {
std::vector<AlgorithmType> algorithms; std::vector<AlgorithmType> algorithms;
CHECK(stream->parent()->GetConvolveBackwardFilterAlgorithms(&algorithms)); CHECK(stream->parent()->GetConvolveBackwardFilterAlgorithms(
conv_parameters.ShouldIncludeWinogradNonfusedAlgo<T>(), &algorithms));
ProfileResult best_result; ProfileResult best_result;
ProfileResult best_result_no_scratch; ProfileResult best_result_no_scratch;
for (auto profile_algorithm : algorithms) { for (auto profile_algorithm : algorithms) {

View File

@ -668,7 +668,8 @@ void LaunchConv2DOp<GPUDevice, T>::launch(
if (cudnn_use_autotune && if (cudnn_use_autotune &&
!AutoTuneConv::GetInstance()->Find(conv_parameters, &algorithm_config)) { !AutoTuneConv::GetInstance()->Find(conv_parameters, &algorithm_config)) {
std::vector<AlgorithmType> algorithms; std::vector<AlgorithmType> algorithms;
CHECK(stream->parent()->GetConvolveAlgorithms(&algorithms)); CHECK(stream->parent()->GetConvolveAlgorithms(
conv_parameters.ShouldIncludeWinogradNonfusedAlgo<T>(), &algorithms));
ProfileResult best_result; ProfileResult best_result;
ProfileResult best_result_no_scratch; ProfileResult best_result_no_scratch;
for (auto profile_algorithm : algorithms) { for (auto profile_algorithm : algorithms) {

View File

@ -392,7 +392,8 @@ struct LaunchConvOp<GPUDevice, T> {
if (cudnn_use_autotune && !AutoTuneConv3d::GetInstance()->Find( if (cudnn_use_autotune && !AutoTuneConv3d::GetInstance()->Find(
conv_parameters, &algorithm_config)) { conv_parameters, &algorithm_config)) {
std::vector<AlgorithmType> algorithms; std::vector<AlgorithmType> algorithms;
CHECK(stream->parent()->GetConvolveAlgorithms(&algorithms)); CHECK(stream->parent()->GetConvolveAlgorithms(
conv_parameters.ShouldIncludeWinogradNonfusedAlgo<T>(), &algorithms));
ProfileResult best_result; ProfileResult best_result;
ProfileResult best_result_no_scratch; ProfileResult best_result_no_scratch;
for (auto profile_algorithm : algorithms) { for (auto profile_algorithm : algorithms) {

View File

@ -145,6 +145,22 @@ class ConvParameters {
// clang-format on // clang-format on
} }
// TODO(yangzihao): The purpose of this function is to disable winograd
// nonfused conv algorithm for certain input parameters so as to avoid a bug
// in cuDNNv5 and cuDNNv6. Remove this once switch to cuDNNv7.
template <typename T>
bool ShouldIncludeWinogradNonfusedAlgo() const {
int64 total_size = 16 * std::ceil(batch_ / 16.0) *
std::max(in_depths_, out_depths_) * in_[0] * in_[1] *
sizeof(T);
int64 threshold = 1L << 31;
if (total_size >= threshold) {
return false;
} else {
return true;
}
}
private: private:
typedef std::tuple<int64, int64, SpatialArray, int64, SpatialArray, typedef std::tuple<int64, int64, SpatialArray, int64, SpatialArray,
SpatialArray, SpatialArray, DataType, int> SpatialArray, SpatialArray, DataType, int>

View File

@ -28,8 +28,50 @@ limitations under the License.
#include "tensorflow/core/platform/test_benchmark.h" #include "tensorflow/core/platform/test_benchmark.h"
#include "tensorflow/core/public/session.h" #include "tensorflow/core/public/session.h"
#include "tensorflow/core/kernels/conv_ops_gpu.h"
namespace tensorflow { namespace tensorflow {
#if GOOGLE_CUDA
TEST(ConvParameters, WinogradNonfusedAlgoSize) {
ConvParameters conv_params_small = {
1, // batch
32, // in_depths
{{300, // in_rows
300}}, // in_cols
128, // out_depths
{{3, // filter_rows
3}}, // filter_cols
{{1, // stride_rows
1}}, // stride_cols
{{0, // padding_rows
0}}, // padding_cols
DT_FLOAT, // tensor datatype
0, // device_id
};
EXPECT_TRUE(conv_params_small.ShouldIncludeWinogradNonfusedAlgo<float>());
ConvParameters conv_params_large = {
1, // batch
128, // in_depths
{{300, // in_rows
300}}, // in_cols
768, // out_depths
{{3, // filter_rows
3}}, // filter_cols
{{1, // stride_rows
1}}, // stride_cols
{{0, // padding_rows
0}}, // padding_cols
DT_FLOAT, // tensor datatype
0, // device_id
};
EXPECT_FALSE(conv_params_large.ShouldIncludeWinogradNonfusedAlgo<float>());
}
#endif // GOOGLE_CUDA
class FusedResizePadConvOpTest : public OpsTestBase { class FusedResizePadConvOpTest : public OpsTestBase {
protected: protected:
void HandwrittenConv() { void HandwrittenConv() {

View File

@ -2071,7 +2071,7 @@ cuda_py_test(
cuda_py_test( cuda_py_test(
name = "conv_ops_test", name = "conv_ops_test",
size = "medium", size = "large",
srcs = ["conv_ops_test.py"], srcs = ["conv_ops_test.py"],
additional_deps = [ additional_deps = [
"//third_party/py/numpy", "//third_party/py/numpy",
@ -2089,6 +2089,7 @@ cuda_py_test(
"//tensorflow/python:random_ops", "//tensorflow/python:random_ops",
"//tensorflow/python:variables", "//tensorflow/python:variables",
], ],
shard_count = 4,
) )
cuda_py_test( cuda_py_test(

View File

@ -189,7 +189,7 @@ class Conv2DTest(test.TestCase):
# numbers from 1. # numbers from 1.
x1 = [f * 1.0 for f in range(1, total_size_1 + 1)] x1 = [f * 1.0 for f in range(1, total_size_1 + 1)]
x2 = [f * 1.0 for f in range(1, total_size_2 + 1)] x2 = [f * 1.0 for f in range(1, total_size_2 + 1)]
with self.test_session(use_gpu=use_gpu) as sess: with self.test_session(use_gpu=use_gpu):
t1 = constant_op.constant(x1, shape=tensor_in_sizes, dtype=dtype) t1 = constant_op.constant(x1, shape=tensor_in_sizes, dtype=dtype)
t2 = constant_op.constant(x2, shape=filter_in_sizes, dtype=dtype) t2 = constant_op.constant(x2, shape=filter_in_sizes, dtype=dtype)
strides = [1] + strides + [1] strides = [1] + strides + [1]
@ -424,7 +424,7 @@ class Conv2DTest(test.TestCase):
x2 = np.random.rand(*output_sizes).astype(np.float32) x2 = np.random.rand(*output_sizes).astype(np.float32)
def _GetVal(data_format, use_gpu): def _GetVal(data_format, use_gpu):
with self.test_session(use_gpu=use_gpu) as sess: with self.test_session(use_gpu=use_gpu):
if data_format == "NCHW": if data_format == "NCHW":
new_input_sizes = test_util.NHWCToNCHW(input_sizes) new_input_sizes = test_util.NHWCToNCHW(input_sizes)
else: else:
@ -580,7 +580,7 @@ class Conv2DTest(test.TestCase):
x2 = np.random.rand(*output_sizes).astype(np.float32) x2 = np.random.rand(*output_sizes).astype(np.float32)
def _GetVal(data_format, use_gpu): def _GetVal(data_format, use_gpu):
with self.test_session(use_gpu=use_gpu) as sess: with self.test_session(use_gpu=use_gpu):
t0 = constant_op.constant(x0, shape=input_sizes) t0 = constant_op.constant(x0, shape=input_sizes)
t1 = constant_op.constant(filter_sizes, shape=[len(filter_sizes)]) t1 = constant_op.constant(filter_sizes, shape=[len(filter_sizes)])
t2 = constant_op.constant(x2, shape=output_sizes) t2 = constant_op.constant(x2, shape=output_sizes)
@ -1444,4 +1444,19 @@ if __name__ == "__main__":
GetInceptionBackFilterTest(input_size_, filter_size_, output_size_, GetInceptionBackFilterTest(input_size_, filter_size_, output_size_,
[stride_, stride_], padding_)) [stride_, stride_], padding_))
# TODO(b/35359731)
# Fwd, BckInput, and BackFilter to test that for certain input parameter
# set, winograd nonfused algorithm will be excluded from conv autotune. If
# in such case, winograd nonfused algorithm is added as one option of the
# conv autotune, and cuDNN version is smaller than 7, the following tests
# will fail.
ishape = [1, 400, 400, 128]
fshape = [3, 3, 128, 768]
oshape = [1, 400, 400, 768]
setattr(Conv2DTest, "testInceptionFwd_No_Winograd_Nonfused",
GetInceptionFwdTest(ishape, fshape, 1, "SAME"))
setattr(Conv2DTest, "testInceptionBackInput_No_Winograd_Nonfused",
GetInceptionBackInputTest(ishape, fshape, oshape, 1, "SAME"))
setattr(Conv2DTest, "testInceptionBackFilter_No_Winograd_Nonfused",
GetInceptionBackFilterTest(ishape, fshape, oshape, [1, 1], "SAME"))
test.main() test.main()

View File

@ -1966,12 +1966,12 @@ bool CudnnSupport::DoConvolveImpl(
} }
// A helper class to decide whether to enable the WINOGRAD_NONFUSED algorithms. // A helper class to decide whether to enable the WINOGRAD_NONFUSED algorithms.
// Doing so by default make a few TensorFlow test cases to fail. Users can // By default it is turned on, users can explicitly disable them through an
// explicitly enable them through an env-var "TF_ENABLE_WINOGRAD_NONFUSED=1". // env-var "TF_ENABLE_WINOGRAD_NONFUSED=0".
// https://github.com/tensorflow/tensorflow/pull/4901 // https://github.com/tensorflow/tensorflow/pull/4901
// TODO(yangzihao): for certain shapes, setting default flag to be true will // TODO(yangzihao): winograd_nonfused bug will only be fixed in cuDNNv7, for
// cause bug and return negative tensor shapes. Will flip the default flag when // cuDNN with smaller versions, we have added code to avoid using winograd
// the bug is fixed. // nonfused for certain input parameter set.
template <bool DefaultFlag> template <bool DefaultFlag>
class WinogradNonfused { class WinogradNonfused {
public: public:
@ -1997,6 +1997,7 @@ class WinogradNonfused {
}; };
bool CudnnSupport::GetConvolveAlgorithms( bool CudnnSupport::GetConvolveAlgorithms(
bool with_winograd_nonfused,
std::vector<dnn::AlgorithmType>* out_algorithms) { std::vector<dnn::AlgorithmType>* out_algorithms) {
out_algorithms->assign({ out_algorithms->assign({
// clang-format off // clang-format off
@ -2012,7 +2013,7 @@ bool CudnnSupport::GetConvolveAlgorithms(
// clang-format on // clang-format on
}); });
#if CUDNN_VERSION >= 5100 #if CUDNN_VERSION >= 5100
if (WinogradNonfused<false>::IsEnabled()) { if (WinogradNonfused<true>::IsEnabled() && with_winograd_nonfused) {
out_algorithms->push_back(CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED); out_algorithms->push_back(CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED);
} }
#endif #endif
@ -2020,6 +2021,7 @@ bool CudnnSupport::GetConvolveAlgorithms(
} }
bool CudnnSupport::GetConvolveBackwardDataAlgorithms( bool CudnnSupport::GetConvolveBackwardDataAlgorithms(
bool with_winograd_nonfused,
std::vector<dnn::AlgorithmType>* out_algorithms) { std::vector<dnn::AlgorithmType>* out_algorithms) {
out_algorithms->assign({ out_algorithms->assign({
// clang-format off // clang-format off
@ -2033,7 +2035,7 @@ bool CudnnSupport::GetConvolveBackwardDataAlgorithms(
// clang-format on // clang-format on
}); });
#if CUDNN_VERSION >= 5100 #if CUDNN_VERSION >= 5100
if (WinogradNonfused<false>::IsEnabled()) { if (WinogradNonfused<true>::IsEnabled() && with_winograd_nonfused) {
out_algorithms->push_back( out_algorithms->push_back(
CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED); CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED);
} }
@ -2042,6 +2044,7 @@ bool CudnnSupport::GetConvolveBackwardDataAlgorithms(
} }
bool CudnnSupport::GetConvolveBackwardFilterAlgorithms( bool CudnnSupport::GetConvolveBackwardFilterAlgorithms(
bool with_winograd_nonfused,
std::vector<dnn::AlgorithmType>* out_algorithms) { std::vector<dnn::AlgorithmType>* out_algorithms) {
out_algorithms->assign({ out_algorithms->assign({
// clang-format off // clang-format off
@ -2053,11 +2056,12 @@ bool CudnnSupport::GetConvolveBackwardFilterAlgorithms(
}); });
#if CUDNN_VERSION >= 5100 #if CUDNN_VERSION >= 5100
#if CUDNN_VERSION >= 5110 #if CUDNN_VERSION >= 5110
static constexpr bool kDefaultFlagWinogradNonfused = false; static constexpr bool kDefaultFlagWinogradNonfused = true;
#else #else
static constexpr bool kDefaultFlagWinogradNonfused = false; static constexpr bool kDefaultFlagWinogradNonfused = false;
#endif #endif
if (WinogradNonfused<kDefaultFlagWinogradNonfused>::IsEnabled()) { if (WinogradNonfused<kDefaultFlagWinogradNonfused>::IsEnabled() &&
with_winograd_nonfused) {
out_algorithms->push_back( out_algorithms->push_back(
// Based on cudnn.h, the following is not implemented. // Based on cudnn.h, the following is not implemented.
// CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD, // CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD,

View File

@ -104,12 +104,15 @@ class CudnnSupport : public dnn::DnnSupport {
ScratchAllocator* workspace_allocator) override; ScratchAllocator* workspace_allocator) override;
bool GetConvolveAlgorithms( bool GetConvolveAlgorithms(
bool with_winograd_nonfused,
std::vector<dnn::AlgorithmType>* out_algorithms) override; std::vector<dnn::AlgorithmType>* out_algorithms) override;
bool GetConvolveBackwardDataAlgorithms( bool GetConvolveBackwardDataAlgorithms(
bool with_winograd_nonfused,
std::vector<dnn::AlgorithmType>* out_algorithms) override; std::vector<dnn::AlgorithmType>* out_algorithms) override;
bool GetConvolveBackwardFilterAlgorithms( bool GetConvolveBackwardFilterAlgorithms(
bool with_winograd_nonfused,
std::vector<dnn::AlgorithmType>* out_algorithms) override; std::vector<dnn::AlgorithmType>* out_algorithms) override;
bool DoBatchNormalizationForward( bool DoBatchNormalizationForward(

View File

@ -23,17 +23,17 @@ namespace gputools {
namespace dnn { namespace dnn {
bool DnnSupport::GetConvolveAlgorithms( bool DnnSupport::GetConvolveAlgorithms(
std::vector<AlgorithmType>* out_algorithms) { bool with_winograd_nonfused, std::vector<AlgorithmType>* out_algorithms) {
return false; return false;
} }
bool DnnSupport::GetConvolveBackwardDataAlgorithms( bool DnnSupport::GetConvolveBackwardDataAlgorithms(
std::vector<AlgorithmType>* out_algorithms) { bool with_winograd_nonfused, std::vector<AlgorithmType>* out_algorithms) {
return false; return false;
} }
bool DnnSupport::GetConvolveBackwardFilterAlgorithms( bool DnnSupport::GetConvolveBackwardFilterAlgorithms(
std::vector<AlgorithmType>* out_algorithms) { bool with_winograd_nonfused, std::vector<AlgorithmType>* out_algorithms) {
return false; return false;
} }

View File

@ -952,7 +952,7 @@ class DnnSupport {
// Return a list of algorithms supported by the forward convolution pass. // Return a list of algorithms supported by the forward convolution pass.
virtual bool GetConvolveAlgorithms( virtual bool GetConvolveAlgorithms(
std::vector<AlgorithmType>* out_algorithms); bool with_winograd_nonfused, std::vector<AlgorithmType>* out_algorithms);
// Enqueues a double-precision convolution operation onto the stream. // Enqueues a double-precision convolution operation onto the stream.
// See DoConvolve above for argument details. // See DoConvolve above for argument details.
@ -1056,7 +1056,7 @@ class DnnSupport {
// Return a list of algorithms supported by the backward convolution pass for // Return a list of algorithms supported by the backward convolution pass for
// data. // data.
virtual bool GetConvolveBackwardDataAlgorithms( virtual bool GetConvolveBackwardDataAlgorithms(
std::vector<AlgorithmType>* out_algorithms); bool with_winograd_nonfused, std::vector<AlgorithmType>* out_algorithms);
virtual bool DoConvolveBackwardData( virtual bool DoConvolveBackwardData(
Stream* stream, const FilterDescriptor& filter_descriptor, Stream* stream, const FilterDescriptor& filter_descriptor,
@ -1104,7 +1104,7 @@ class DnnSupport {
// Return a list of algorithms supported by the backward convolution pass for // Return a list of algorithms supported by the backward convolution pass for
// filters. // filters.
virtual bool GetConvolveBackwardFilterAlgorithms( virtual bool GetConvolveBackwardFilterAlgorithms(
std::vector<AlgorithmType>* out_algorithms); bool with_winograd_nonfused, std::vector<AlgorithmType>* out_algorithms);
virtual bool DoConvolveBackwardFilter( virtual bool DoConvolveBackwardFilter(
Stream* stream, const BatchDescriptor& input_descriptor, Stream* stream, const BatchDescriptor& input_descriptor,

View File

@ -285,30 +285,36 @@ bool StreamExecutor::SupportsDnn() const {
} }
bool StreamExecutor::GetConvolveAlgorithms( bool StreamExecutor::GetConvolveAlgorithms(
bool with_winograd_nonfused,
std::vector<dnn::AlgorithmType> *out_algorithms) { std::vector<dnn::AlgorithmType> *out_algorithms) {
dnn::DnnSupport *dnn_support = AsDnn(); dnn::DnnSupport *dnn_support = AsDnn();
if (!dnn_support) { if (!dnn_support) {
return false; return false;
} }
return dnn_support->GetConvolveAlgorithms(out_algorithms); return dnn_support->GetConvolveAlgorithms(with_winograd_nonfused,
out_algorithms);
} }
bool StreamExecutor::GetConvolveBackwardDataAlgorithms( bool StreamExecutor::GetConvolveBackwardDataAlgorithms(
bool with_winograd_nonfused,
std::vector<dnn::AlgorithmType> *out_algorithms) { std::vector<dnn::AlgorithmType> *out_algorithms) {
dnn::DnnSupport *dnn_support = AsDnn(); dnn::DnnSupport *dnn_support = AsDnn();
if (!dnn_support) { if (!dnn_support) {
return false; return false;
} }
return dnn_support->GetConvolveBackwardDataAlgorithms(out_algorithms); return dnn_support->GetConvolveBackwardDataAlgorithms(with_winograd_nonfused,
out_algorithms);
} }
bool StreamExecutor::GetConvolveBackwardFilterAlgorithms( bool StreamExecutor::GetConvolveBackwardFilterAlgorithms(
bool with_winograd_nonfused,
std::vector<dnn::AlgorithmType> *out_algorithms) { std::vector<dnn::AlgorithmType> *out_algorithms) {
dnn::DnnSupport *dnn_support = AsDnn(); dnn::DnnSupport *dnn_support = AsDnn();
if (!dnn_support) { if (!dnn_support) {
return false; return false;
} }
return dnn_support->GetConvolveBackwardFilterAlgorithms(out_algorithms); return dnn_support->GetConvolveBackwardFilterAlgorithms(
with_winograd_nonfused, out_algorithms);
} }
bool StreamExecutor::GetBlasGemmAlgorithms( bool StreamExecutor::GetBlasGemmAlgorithms(

View File

@ -342,15 +342,18 @@ class StreamExecutor {
bool SupportsDnn() const; bool SupportsDnn() const;
// Get the list of supported algorithms for the forward convolution opeartion. // Get the list of supported algorithms for the forward convolution opeartion.
bool GetConvolveAlgorithms(std::vector<dnn::AlgorithmType> *out_algorithms); bool GetConvolveAlgorithms(bool with_winograd_nonfused,
std::vector<dnn::AlgorithmType> *out_algorithms);
// Get the list of supported algorithms for the backward convolution on data. // Get the list of supported algorithms for the backward convolution on data.
bool GetConvolveBackwardDataAlgorithms( bool GetConvolveBackwardDataAlgorithms(
bool with_winograd_nonfused,
std::vector<dnn::AlgorithmType> *out_algorithms); std::vector<dnn::AlgorithmType> *out_algorithms);
// Get the list of supported algorithms for the backward convolution on the // Get the list of supported algorithms for the backward convolution on the
// filter. // filter.
bool GetConvolveBackwardFilterAlgorithms( bool GetConvolveBackwardFilterAlgorithms(
bool with_winograd_nonfused,
std::vector<dnn::AlgorithmType> *out_algorithms); std::vector<dnn::AlgorithmType> *out_algorithms);
// Get the list of supported algorithms for BLAS gemm. // Get the list of supported algorithms for BLAS gemm.