mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-07 12:20:24 +01:00
Set winograd nofused flag to be true by default.
Disable winograd nonfused conv for certain input params to avoid a known bug in cuDNNv5 and cuDNNv6. PiperOrigin-RevId: 157352847
This commit is contained in:
parent
3f9b69a50f
commit
e78e5ec8a8
|
|
@ -258,15 +258,21 @@ tensorflow::Status ConvolutionThunk::Convolve(
|
|||
std::vector<se::dnn::AlgorithmType> ConvolutionThunk::GetAlgorithms(
|
||||
se::StreamExecutor* stream_exec) const {
|
||||
std::vector<se::dnn::AlgorithmType> algorithms;
|
||||
// TODO(yangzihao): Currently disable the use of winograd nonfused in XLA
|
||||
// by default. Should send in conv parameters and enable it when
|
||||
// ShouldIncludeWinogradNonfusedAlgo() returns true.
|
||||
switch (convolution_kind_) {
|
||||
case ConvolutionKind::kBackwardFilter:
|
||||
CHECK(stream_exec->GetConvolveBackwardFilterAlgorithms(&algorithms));
|
||||
CHECK(stream_exec->GetConvolveBackwardFilterAlgorithms(
|
||||
/*with_winograd_nonfused=*/false, &algorithms));
|
||||
break;
|
||||
case ConvolutionKind::kBackwardInput:
|
||||
CHECK(stream_exec->GetConvolveBackwardDataAlgorithms(&algorithms));
|
||||
CHECK(stream_exec->GetConvolveBackwardDataAlgorithms(
|
||||
/*with_winograd_nonfused=*/false, &algorithms));
|
||||
break;
|
||||
case ConvolutionKind::kForward:
|
||||
CHECK(stream_exec->GetConvolveAlgorithms(&algorithms));
|
||||
CHECK(stream_exec->GetConvolveAlgorithms(/*with_winograd_nonfused=*/false,
|
||||
&algorithms));
|
||||
break;
|
||||
}
|
||||
return algorithms;
|
||||
|
|
|
|||
|
|
@ -776,7 +776,8 @@ class Conv2DSlowBackpropFilterOp : public OpKernel {
|
|||
if (cudnn_use_autotune_ && !AutoTuneConvBwdFilter::GetInstance()->Find(
|
||||
conv_parameters, &algorithm_config)) {
|
||||
std::vector<AlgorithmType> algorithms;
|
||||
CHECK(stream->parent()->GetConvolveBackwardFilterAlgorithms(&algorithms));
|
||||
CHECK(stream->parent()->GetConvolveBackwardFilterAlgorithms(
|
||||
conv_parameters.ShouldIncludeWinogradNonfusedAlgo<T>(), &algorithms));
|
||||
ProfileResult best_result;
|
||||
ProfileResult best_result_no_scratch;
|
||||
for (auto profile_algorithm : algorithms) {
|
||||
|
|
|
|||
|
|
@ -856,7 +856,8 @@ class Conv2DSlowBackpropInputOp : public OpKernel {
|
|||
if (cudnn_use_autotune_ && !AutoTuneConvBwdData::GetInstance()->Find(
|
||||
conv_parameters, &algorithm_config)) {
|
||||
std::vector<AlgorithmType> algorithms;
|
||||
CHECK(stream->parent()->GetConvolveBackwardDataAlgorithms(&algorithms));
|
||||
CHECK(stream->parent()->GetConvolveBackwardDataAlgorithms(
|
||||
conv_parameters.ShouldIncludeWinogradNonfusedAlgo<T>(), &algorithms));
|
||||
ProfileResult best_result;
|
||||
ProfileResult best_result_no_scratch;
|
||||
for (auto profile_algorithm : algorithms) {
|
||||
|
|
|
|||
|
|
@ -656,7 +656,8 @@ class Conv3DBackpropInputOp<GPUDevice, T> : public OpKernel {
|
|||
if (cudnn_use_autotune_ && !AutoTuneConv3dBwdData::GetInstance()->Find(
|
||||
conv_parameters, &algorithm_config)) {
|
||||
std::vector<AlgorithmType> algorithms;
|
||||
CHECK(stream->parent()->GetConvolveBackwardDataAlgorithms(&algorithms));
|
||||
CHECK(stream->parent()->GetConvolveBackwardDataAlgorithms(
|
||||
conv_parameters.ShouldIncludeWinogradNonfusedAlgo<T>(), &algorithms));
|
||||
ProfileResult best_result;
|
||||
ProfileResult best_result_no_scratch;
|
||||
for (auto profile_algorithm : algorithms) {
|
||||
|
|
@ -1020,11 +1021,11 @@ class Conv3DBackpropFilterOp<GPUDevice, T> : public OpKernel {
|
|||
using perftools::gputools::dnn::ProfileResult;
|
||||
using perftools::gputools::dnn::kDefaultAlgorithm;
|
||||
AlgorithmConfig algorithm_config;
|
||||
|
||||
if (cudnn_use_autotune_ && !AutoTuneConv3dBwdFilter::GetInstance()->Find(
|
||||
conv_parameters, &algorithm_config)) {
|
||||
std::vector<AlgorithmType> algorithms;
|
||||
CHECK(stream->parent()->GetConvolveBackwardFilterAlgorithms(&algorithms));
|
||||
CHECK(stream->parent()->GetConvolveBackwardFilterAlgorithms(
|
||||
conv_parameters.ShouldIncludeWinogradNonfusedAlgo<T>(), &algorithms));
|
||||
ProfileResult best_result;
|
||||
ProfileResult best_result_no_scratch;
|
||||
for (auto profile_algorithm : algorithms) {
|
||||
|
|
|
|||
|
|
@ -668,7 +668,8 @@ void LaunchConv2DOp<GPUDevice, T>::launch(
|
|||
if (cudnn_use_autotune &&
|
||||
!AutoTuneConv::GetInstance()->Find(conv_parameters, &algorithm_config)) {
|
||||
std::vector<AlgorithmType> algorithms;
|
||||
CHECK(stream->parent()->GetConvolveAlgorithms(&algorithms));
|
||||
CHECK(stream->parent()->GetConvolveAlgorithms(
|
||||
conv_parameters.ShouldIncludeWinogradNonfusedAlgo<T>(), &algorithms));
|
||||
ProfileResult best_result;
|
||||
ProfileResult best_result_no_scratch;
|
||||
for (auto profile_algorithm : algorithms) {
|
||||
|
|
|
|||
|
|
@ -392,7 +392,8 @@ struct LaunchConvOp<GPUDevice, T> {
|
|||
if (cudnn_use_autotune && !AutoTuneConv3d::GetInstance()->Find(
|
||||
conv_parameters, &algorithm_config)) {
|
||||
std::vector<AlgorithmType> algorithms;
|
||||
CHECK(stream->parent()->GetConvolveAlgorithms(&algorithms));
|
||||
CHECK(stream->parent()->GetConvolveAlgorithms(
|
||||
conv_parameters.ShouldIncludeWinogradNonfusedAlgo<T>(), &algorithms));
|
||||
ProfileResult best_result;
|
||||
ProfileResult best_result_no_scratch;
|
||||
for (auto profile_algorithm : algorithms) {
|
||||
|
|
|
|||
|
|
@ -145,6 +145,22 @@ class ConvParameters {
|
|||
// clang-format on
|
||||
}
|
||||
|
||||
// TODO(yangzihao): The purpose of this function is to disable winograd
|
||||
// nonfused conv algorithm for certain input parameters so as to avoid a bug
|
||||
// in cuDNNv5 and cuDNNv6. Remove this once switch to cuDNNv7.
|
||||
template <typename T>
|
||||
bool ShouldIncludeWinogradNonfusedAlgo() const {
|
||||
int64 total_size = 16 * std::ceil(batch_ / 16.0) *
|
||||
std::max(in_depths_, out_depths_) * in_[0] * in_[1] *
|
||||
sizeof(T);
|
||||
int64 threshold = 1L << 31;
|
||||
if (total_size >= threshold) {
|
||||
return false;
|
||||
} else {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
typedef std::tuple<int64, int64, SpatialArray, int64, SpatialArray,
|
||||
SpatialArray, SpatialArray, DataType, int>
|
||||
|
|
|
|||
|
|
@ -28,8 +28,50 @@ limitations under the License.
|
|||
#include "tensorflow/core/platform/test_benchmark.h"
|
||||
#include "tensorflow/core/public/session.h"
|
||||
|
||||
#include "tensorflow/core/kernels/conv_ops_gpu.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
|
||||
TEST(ConvParameters, WinogradNonfusedAlgoSize) {
|
||||
ConvParameters conv_params_small = {
|
||||
1, // batch
|
||||
32, // in_depths
|
||||
{{300, // in_rows
|
||||
300}}, // in_cols
|
||||
128, // out_depths
|
||||
{{3, // filter_rows
|
||||
3}}, // filter_cols
|
||||
{{1, // stride_rows
|
||||
1}}, // stride_cols
|
||||
{{0, // padding_rows
|
||||
0}}, // padding_cols
|
||||
DT_FLOAT, // tensor datatype
|
||||
0, // device_id
|
||||
};
|
||||
EXPECT_TRUE(conv_params_small.ShouldIncludeWinogradNonfusedAlgo<float>());
|
||||
|
||||
ConvParameters conv_params_large = {
|
||||
1, // batch
|
||||
128, // in_depths
|
||||
{{300, // in_rows
|
||||
300}}, // in_cols
|
||||
768, // out_depths
|
||||
{{3, // filter_rows
|
||||
3}}, // filter_cols
|
||||
{{1, // stride_rows
|
||||
1}}, // stride_cols
|
||||
{{0, // padding_rows
|
||||
0}}, // padding_cols
|
||||
DT_FLOAT, // tensor datatype
|
||||
0, // device_id
|
||||
};
|
||||
EXPECT_FALSE(conv_params_large.ShouldIncludeWinogradNonfusedAlgo<float>());
|
||||
}
|
||||
|
||||
#endif // GOOGLE_CUDA
|
||||
|
||||
class FusedResizePadConvOpTest : public OpsTestBase {
|
||||
protected:
|
||||
void HandwrittenConv() {
|
||||
|
|
|
|||
|
|
@ -2071,7 +2071,7 @@ cuda_py_test(
|
|||
|
||||
cuda_py_test(
|
||||
name = "conv_ops_test",
|
||||
size = "medium",
|
||||
size = "large",
|
||||
srcs = ["conv_ops_test.py"],
|
||||
additional_deps = [
|
||||
"//third_party/py/numpy",
|
||||
|
|
@ -2089,6 +2089,7 @@ cuda_py_test(
|
|||
"//tensorflow/python:random_ops",
|
||||
"//tensorflow/python:variables",
|
||||
],
|
||||
shard_count = 4,
|
||||
)
|
||||
|
||||
cuda_py_test(
|
||||
|
|
|
|||
|
|
@ -189,7 +189,7 @@ class Conv2DTest(test.TestCase):
|
|||
# numbers from 1.
|
||||
x1 = [f * 1.0 for f in range(1, total_size_1 + 1)]
|
||||
x2 = [f * 1.0 for f in range(1, total_size_2 + 1)]
|
||||
with self.test_session(use_gpu=use_gpu) as sess:
|
||||
with self.test_session(use_gpu=use_gpu):
|
||||
t1 = constant_op.constant(x1, shape=tensor_in_sizes, dtype=dtype)
|
||||
t2 = constant_op.constant(x2, shape=filter_in_sizes, dtype=dtype)
|
||||
strides = [1] + strides + [1]
|
||||
|
|
@ -378,7 +378,7 @@ class Conv2DTest(test.TestCase):
|
|||
expected=[50, 60])
|
||||
|
||||
# TODO this currently fails.
|
||||
#self._VerifyValues(tensor_in_sizes=[1, 8, 8, 1],
|
||||
# self._VerifyValues(tensor_in_sizes=[1, 8, 8, 1],
|
||||
# filter_in_sizes=[2, 2, 1, 1],
|
||||
# strides=[4, 4], padding="SAME",
|
||||
# expected=[72, 112, 392, 432])
|
||||
|
|
@ -424,7 +424,7 @@ class Conv2DTest(test.TestCase):
|
|||
x2 = np.random.rand(*output_sizes).astype(np.float32)
|
||||
|
||||
def _GetVal(data_format, use_gpu):
|
||||
with self.test_session(use_gpu=use_gpu) as sess:
|
||||
with self.test_session(use_gpu=use_gpu):
|
||||
if data_format == "NCHW":
|
||||
new_input_sizes = test_util.NHWCToNCHW(input_sizes)
|
||||
else:
|
||||
|
|
@ -580,7 +580,7 @@ class Conv2DTest(test.TestCase):
|
|||
x2 = np.random.rand(*output_sizes).astype(np.float32)
|
||||
|
||||
def _GetVal(data_format, use_gpu):
|
||||
with self.test_session(use_gpu=use_gpu) as sess:
|
||||
with self.test_session(use_gpu=use_gpu):
|
||||
t0 = constant_op.constant(x0, shape=input_sizes)
|
||||
t1 = constant_op.constant(filter_sizes, shape=[len(filter_sizes)])
|
||||
t2 = constant_op.constant(x2, shape=output_sizes)
|
||||
|
|
@ -1444,4 +1444,19 @@ if __name__ == "__main__":
|
|||
GetInceptionBackFilterTest(input_size_, filter_size_, output_size_,
|
||||
[stride_, stride_], padding_))
|
||||
|
||||
# TODO(b/35359731)
|
||||
# Fwd, BckInput, and BackFilter to test that for certain input parameter
|
||||
# set, winograd nonfused algorithm will be excluded from conv autotune. If
|
||||
# in such case, winograd nonfused algorithm is added as one option of the
|
||||
# conv autotune, and cuDNN version is smaller than 7, the following tests
|
||||
# will fail.
|
||||
ishape = [1, 400, 400, 128]
|
||||
fshape = [3, 3, 128, 768]
|
||||
oshape = [1, 400, 400, 768]
|
||||
setattr(Conv2DTest, "testInceptionFwd_No_Winograd_Nonfused",
|
||||
GetInceptionFwdTest(ishape, fshape, 1, "SAME"))
|
||||
setattr(Conv2DTest, "testInceptionBackInput_No_Winograd_Nonfused",
|
||||
GetInceptionBackInputTest(ishape, fshape, oshape, 1, "SAME"))
|
||||
setattr(Conv2DTest, "testInceptionBackFilter_No_Winograd_Nonfused",
|
||||
GetInceptionBackFilterTest(ishape, fshape, oshape, [1, 1], "SAME"))
|
||||
test.main()
|
||||
|
|
|
|||
|
|
@ -1966,12 +1966,12 @@ bool CudnnSupport::DoConvolveImpl(
|
|||
}
|
||||
|
||||
// A helper class to decide whether to enable the WINOGRAD_NONFUSED algorithms.
|
||||
// Doing so by default make a few TensorFlow test cases to fail. Users can
|
||||
// explicitly enable them through an env-var "TF_ENABLE_WINOGRAD_NONFUSED=1".
|
||||
// By default it is turned on, users can explicitly disable them through an
|
||||
// env-var "TF_ENABLE_WINOGRAD_NONFUSED=0".
|
||||
// https://github.com/tensorflow/tensorflow/pull/4901
|
||||
// TODO(yangzihao): for certain shapes, setting default flag to be true will
|
||||
// cause bug and return negative tensor shapes. Will flip the default flag when
|
||||
// the bug is fixed.
|
||||
// TODO(yangzihao): winograd_nonfused bug will only be fixed in cuDNNv7, for
|
||||
// cuDNN with smaller versions, we have added code to avoid using winograd
|
||||
// nonfused for certain input parameter set.
|
||||
template <bool DefaultFlag>
|
||||
class WinogradNonfused {
|
||||
public:
|
||||
|
|
@ -1997,6 +1997,7 @@ class WinogradNonfused {
|
|||
};
|
||||
|
||||
bool CudnnSupport::GetConvolveAlgorithms(
|
||||
bool with_winograd_nonfused,
|
||||
std::vector<dnn::AlgorithmType>* out_algorithms) {
|
||||
out_algorithms->assign({
|
||||
// clang-format off
|
||||
|
|
@ -2012,7 +2013,7 @@ bool CudnnSupport::GetConvolveAlgorithms(
|
|||
// clang-format on
|
||||
});
|
||||
#if CUDNN_VERSION >= 5100
|
||||
if (WinogradNonfused<false>::IsEnabled()) {
|
||||
if (WinogradNonfused<true>::IsEnabled() && with_winograd_nonfused) {
|
||||
out_algorithms->push_back(CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED);
|
||||
}
|
||||
#endif
|
||||
|
|
@ -2020,6 +2021,7 @@ bool CudnnSupport::GetConvolveAlgorithms(
|
|||
}
|
||||
|
||||
bool CudnnSupport::GetConvolveBackwardDataAlgorithms(
|
||||
bool with_winograd_nonfused,
|
||||
std::vector<dnn::AlgorithmType>* out_algorithms) {
|
||||
out_algorithms->assign({
|
||||
// clang-format off
|
||||
|
|
@ -2033,7 +2035,7 @@ bool CudnnSupport::GetConvolveBackwardDataAlgorithms(
|
|||
// clang-format on
|
||||
});
|
||||
#if CUDNN_VERSION >= 5100
|
||||
if (WinogradNonfused<false>::IsEnabled()) {
|
||||
if (WinogradNonfused<true>::IsEnabled() && with_winograd_nonfused) {
|
||||
out_algorithms->push_back(
|
||||
CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED);
|
||||
}
|
||||
|
|
@ -2042,6 +2044,7 @@ bool CudnnSupport::GetConvolveBackwardDataAlgorithms(
|
|||
}
|
||||
|
||||
bool CudnnSupport::GetConvolveBackwardFilterAlgorithms(
|
||||
bool with_winograd_nonfused,
|
||||
std::vector<dnn::AlgorithmType>* out_algorithms) {
|
||||
out_algorithms->assign({
|
||||
// clang-format off
|
||||
|
|
@ -2053,11 +2056,12 @@ bool CudnnSupport::GetConvolveBackwardFilterAlgorithms(
|
|||
});
|
||||
#if CUDNN_VERSION >= 5100
|
||||
#if CUDNN_VERSION >= 5110
|
||||
static constexpr bool kDefaultFlagWinogradNonfused = false;
|
||||
static constexpr bool kDefaultFlagWinogradNonfused = true;
|
||||
#else
|
||||
static constexpr bool kDefaultFlagWinogradNonfused = false;
|
||||
#endif
|
||||
if (WinogradNonfused<kDefaultFlagWinogradNonfused>::IsEnabled()) {
|
||||
if (WinogradNonfused<kDefaultFlagWinogradNonfused>::IsEnabled() &&
|
||||
with_winograd_nonfused) {
|
||||
out_algorithms->push_back(
|
||||
// Based on cudnn.h, the following is not implemented.
|
||||
// CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD,
|
||||
|
|
|
|||
|
|
@ -104,12 +104,15 @@ class CudnnSupport : public dnn::DnnSupport {
|
|||
ScratchAllocator* workspace_allocator) override;
|
||||
|
||||
bool GetConvolveAlgorithms(
|
||||
bool with_winograd_nonfused,
|
||||
std::vector<dnn::AlgorithmType>* out_algorithms) override;
|
||||
|
||||
bool GetConvolveBackwardDataAlgorithms(
|
||||
bool with_winograd_nonfused,
|
||||
std::vector<dnn::AlgorithmType>* out_algorithms) override;
|
||||
|
||||
bool GetConvolveBackwardFilterAlgorithms(
|
||||
bool with_winograd_nonfused,
|
||||
std::vector<dnn::AlgorithmType>* out_algorithms) override;
|
||||
|
||||
bool DoBatchNormalizationForward(
|
||||
|
|
|
|||
|
|
@ -23,17 +23,17 @@ namespace gputools {
|
|||
namespace dnn {
|
||||
|
||||
bool DnnSupport::GetConvolveAlgorithms(
|
||||
std::vector<AlgorithmType>* out_algorithms) {
|
||||
bool with_winograd_nonfused, std::vector<AlgorithmType>* out_algorithms) {
|
||||
return false;
|
||||
}
|
||||
|
||||
bool DnnSupport::GetConvolveBackwardDataAlgorithms(
|
||||
std::vector<AlgorithmType>* out_algorithms) {
|
||||
bool with_winograd_nonfused, std::vector<AlgorithmType>* out_algorithms) {
|
||||
return false;
|
||||
}
|
||||
|
||||
bool DnnSupport::GetConvolveBackwardFilterAlgorithms(
|
||||
std::vector<AlgorithmType>* out_algorithms) {
|
||||
bool with_winograd_nonfused, std::vector<AlgorithmType>* out_algorithms) {
|
||||
return false;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -952,7 +952,7 @@ class DnnSupport {
|
|||
|
||||
// Return a list of algorithms supported by the forward convolution pass.
|
||||
virtual bool GetConvolveAlgorithms(
|
||||
std::vector<AlgorithmType>* out_algorithms);
|
||||
bool with_winograd_nonfused, std::vector<AlgorithmType>* out_algorithms);
|
||||
|
||||
// Enqueues a double-precision convolution operation onto the stream.
|
||||
// See DoConvolve above for argument details.
|
||||
|
|
@ -1056,7 +1056,7 @@ class DnnSupport {
|
|||
// Return a list of algorithms supported by the backward convolution pass for
|
||||
// data.
|
||||
virtual bool GetConvolveBackwardDataAlgorithms(
|
||||
std::vector<AlgorithmType>* out_algorithms);
|
||||
bool with_winograd_nonfused, std::vector<AlgorithmType>* out_algorithms);
|
||||
|
||||
virtual bool DoConvolveBackwardData(
|
||||
Stream* stream, const FilterDescriptor& filter_descriptor,
|
||||
|
|
@ -1104,7 +1104,7 @@ class DnnSupport {
|
|||
// Return a list of algorithms supported by the backward convolution pass for
|
||||
// filters.
|
||||
virtual bool GetConvolveBackwardFilterAlgorithms(
|
||||
std::vector<AlgorithmType>* out_algorithms);
|
||||
bool with_winograd_nonfused, std::vector<AlgorithmType>* out_algorithms);
|
||||
|
||||
virtual bool DoConvolveBackwardFilter(
|
||||
Stream* stream, const BatchDescriptor& input_descriptor,
|
||||
|
|
|
|||
|
|
@ -285,30 +285,36 @@ bool StreamExecutor::SupportsDnn() const {
|
|||
}
|
||||
|
||||
bool StreamExecutor::GetConvolveAlgorithms(
|
||||
bool with_winograd_nonfused,
|
||||
std::vector<dnn::AlgorithmType> *out_algorithms) {
|
||||
dnn::DnnSupport *dnn_support = AsDnn();
|
||||
if (!dnn_support) {
|
||||
return false;
|
||||
}
|
||||
return dnn_support->GetConvolveAlgorithms(out_algorithms);
|
||||
return dnn_support->GetConvolveAlgorithms(with_winograd_nonfused,
|
||||
out_algorithms);
|
||||
}
|
||||
|
||||
bool StreamExecutor::GetConvolveBackwardDataAlgorithms(
|
||||
bool with_winograd_nonfused,
|
||||
std::vector<dnn::AlgorithmType> *out_algorithms) {
|
||||
dnn::DnnSupport *dnn_support = AsDnn();
|
||||
if (!dnn_support) {
|
||||
return false;
|
||||
}
|
||||
return dnn_support->GetConvolveBackwardDataAlgorithms(out_algorithms);
|
||||
return dnn_support->GetConvolveBackwardDataAlgorithms(with_winograd_nonfused,
|
||||
out_algorithms);
|
||||
}
|
||||
|
||||
bool StreamExecutor::GetConvolveBackwardFilterAlgorithms(
|
||||
bool with_winograd_nonfused,
|
||||
std::vector<dnn::AlgorithmType> *out_algorithms) {
|
||||
dnn::DnnSupport *dnn_support = AsDnn();
|
||||
if (!dnn_support) {
|
||||
return false;
|
||||
}
|
||||
return dnn_support->GetConvolveBackwardFilterAlgorithms(out_algorithms);
|
||||
return dnn_support->GetConvolveBackwardFilterAlgorithms(
|
||||
with_winograd_nonfused, out_algorithms);
|
||||
}
|
||||
|
||||
bool StreamExecutor::GetBlasGemmAlgorithms(
|
||||
|
|
|
|||
|
|
@ -342,15 +342,18 @@ class StreamExecutor {
|
|||
bool SupportsDnn() const;
|
||||
|
||||
// Get the list of supported algorithms for the forward convolution opeartion.
|
||||
bool GetConvolveAlgorithms(std::vector<dnn::AlgorithmType> *out_algorithms);
|
||||
bool GetConvolveAlgorithms(bool with_winograd_nonfused,
|
||||
std::vector<dnn::AlgorithmType> *out_algorithms);
|
||||
|
||||
// Get the list of supported algorithms for the backward convolution on data.
|
||||
bool GetConvolveBackwardDataAlgorithms(
|
||||
bool with_winograd_nonfused,
|
||||
std::vector<dnn::AlgorithmType> *out_algorithms);
|
||||
|
||||
// Get the list of supported algorithms for the backward convolution on the
|
||||
// filter.
|
||||
bool GetConvolveBackwardFilterAlgorithms(
|
||||
bool with_winograd_nonfused,
|
||||
std::vector<dnn::AlgorithmType> *out_algorithms);
|
||||
|
||||
// Get the list of supported algorithms for BLAS gemm.
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user