Make bias optional in cuDNN conv op
Summary: Yangqing, this seems to work for me; not sure if it's implemented in the right way for you to accept :) Allows the user to specify a "no_bias" option for convolution layers (cuDNN only at this point), so that the bias associated with that operator is not allocated or computed. This is useful in particular for conv + BatchNorm combinations (such as ResNets), where the bias term would otherwise be handled by both the conv and the BatchNorm, wasting memory and computation.

Closes https://github.com/caffe2/caffe2/pull/50

Reviewed By: Yangqing

Differential Revision: D4341288

Pulled By: bwasti

fbshipit-source-id: e6138d0024c83ed876dff2f83ffbebe7de502fd8
Commit 05233cd5b8 (parent fe38a0c2b1)
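For context, a minimal usage sketch of the new option through the Caffe2 Python layer (blob names, dimensions, and the SpatialBN/Relu calls are illustrative, and it assumes a build that includes this change): with no_bias=1 the Conv op receives only the data and filter blobs, so no bias blob is created, initialized, or trained, and BatchNorm supplies the additive term instead.

from caffe2.python import cnn

# Hypothetical conv + BatchNorm block (names and dims are made up).
# no_bias=1 is forwarded to the Conv op, so only [data, conv1_w] are passed
# as inputs and no conv1_b blob is allocated.
model = cnn.CNNModelHelper(order="NCHW", name="example", use_cudnn=True)
conv1 = model.Conv("data", "conv1", dim_in=3, dim_out=64, kernel=3, pad=1, no_bias=1)
bn1 = model.SpatialBN(conv1, "conv1_bn", 64, epsilon=1e-3, is_test=False)
relu1 = model.Relu(bn1, "conv1_relu")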
@@ -7,7 +7,7 @@ REGISTER_CPU_OPERATOR(Conv, ConvOp<float, CPUContext>);
 REGISTER_CPU_OPERATOR(ConvGradient, ConvGradientOp<float, CPUContext>);
 
 OPERATOR_SCHEMA(Conv)
-    .NumInputs(3)
+    .NumInputs(2,3)
     .NumOutputs(1)
     .SetDoc(R"DOC(
 The convolution operator consumes an input vector, the filter blob and the bias
@@ -35,16 +35,32 @@ why they are separate files.
     "stride size, and pad lengths."
     "");
 
-OPERATOR_SCHEMA(ConvGradient).NumInputs(3).NumOutputs(2, 3);
+OPERATOR_SCHEMA(ConvGradient).NumInputs(2,3).NumOutputs(2, 3);
 
 class GetConvGradient : public GradientMakerBase {
   using GradientMakerBase::GradientMakerBase;
   vector<OperatorDef> GetGradientDefs() override {
-    CAFFE_ENFORCE(3 == def_.input_size());
+    // todo(slayton) don't do this.
+    // CAFFE_ENFORCE(3 == def_.input_size());
+
+    vector<string> inputs, outputs;
+
+    ArgumentHelper helper(def_);
+    bool no_bias = static_cast<bool>(helper.GetSingleArgument<int>("no_bias", 0));
+
+    if (no_bias) {
+      // no bias - same inputs, only output dW, dY
+      inputs = vector<string>{I(0), I(1), GO(0)};
+      outputs = vector<string>{GI(1), GI(0)};
+    } else {
+      inputs = vector<string>{I(0), I(1), GO(0)};
+      outputs = vector<string>{GI(1), GI(2), GI(0)};
+    }
+
     return SingleGradientDef(
         "ConvGradient", "",
-        vector<string>{I(0), I(1), GO(0)},
-        vector<string>{GI(1), GI(2), GI(0)});
+        inputs,
+        outputs);
   }
 };
 REGISTER_GRADIENT(Conv, GetConvGradient);
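For reference, a quick way to see the effect of the gradient mapping above from Python (a sketch only; blob names are illustrative and it assumes a Caffe2 build that includes this change): with no_bias=1 the generated ConvGradient op lists only the filter and input gradients as outputs, with no bias gradient.

from caffe2.python import core

# Sketch: inspect the gradient ops generated for a bias-free Conv.
conv = core.CreateOperator("Conv", ["X", "W"], ["Y"], kernel=3, no_bias=1)
grad_ops, _ = core.GradientRegistry.GetGradientForOp(conv, ["Y_grad"])
# Expect outputs like [W_grad, X_grad] here; with a bias input it would be
# [W_grad, b_grad, X_grad].
print(grad_ops[0].type, list(grad_ops[0].output))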
@@ -43,7 +43,7 @@ class ConvGradientOp final : public ConvPoolOpBase<Context> {
   // input: X, W, dY
   // output: dW, db, and optionally dX
   INPUT_TAGS(INPUT, FILTER, OUTPUT_GRAD);
-  OUTPUT_TAGS(FILTER_GRAD, BIAS_GRAD, INPUT_GRAD);
+  OUTPUT_TAGS(FILTER_GRAD, BIAS_OR_INPUT_GRAD, INPUT_GRAD);
 };
 
 }  // namespace caffe2
@@ -1,5 +1,6 @@
 #include "caffe2/core/common_cudnn.h"
 #include "caffe2/core/context_gpu.h"
 #include "caffe2/operators/conv_op.h"
+#include "caffe2/operators/conv_op_cache_cudnn.h"
 #include "caffe2/operators/conv_pool_op_base.h"
 
@@ -44,8 +45,10 @@ class CudnnConvOpBase : public ConvPoolOpBase<CUDAContext> {
            OperatorBase::GetSingleArgument<int>("exhaustive_search", 0)),
        deterministic_(
            OperatorBase::GetSingleArgument<int>("deterministic", 0)),
-       cudnn_state_(OperatorBase::GetSingleArgument<int>("cudnn_state", 0)) {
-    CAFFE_ENFORCE(!deterministic_ || !exhaustive_search_);
+       cudnn_state_(OperatorBase::GetSingleArgument<int>("cudnn_state", 0)),
+       no_bias_(OperatorBase::GetSingleArgument<int>("no_bias", 0)) {
+    bias_ = (!no_bias_) ? true : false;
+    CHECK(!deterministic_ || !exhaustive_search_);
     CUDNN_CHECK(cudnnCreateTensorDescriptor(&bottom_desc_));
     CUDNN_CHECK(cudnnCreateFilterDescriptor(&filter_desc_));
     CUDNN_CHECK(cudnnCreateTensorDescriptor(&bias_desc_));
@@ -76,6 +79,8 @@ class CudnnConvOpBase : public ConvPoolOpBase<CUDAContext> {
   bool exhaustive_search_;
   bool deterministic_;
   size_t cudnn_state_;
+  int no_bias_;
+  bool bias_;
 };
 
 template <typename T>
@@ -114,7 +119,7 @@ class CudnnConvGradientOp final : public CudnnConvOpBase {
   // input: X, W, dY
   // output: dW, db, and optionally dX
   INPUT_TAGS(INPUT, FILTER, OUTPUT_GRAD);
-  OUTPUT_TAGS(FILTER_GRAD, BIAS_GRAD, INPUT_GRAD);
+  OUTPUT_TAGS(FILTER_GRAD, BIAS_OR_INPUT_GRAD, INPUT_GRAD);
 };
 
 ////////////////////////////////////////////////////////////////////////////////
@@ -125,7 +130,6 @@ template <typename T>
 bool CudnnConvOp<T>::RunOnDevice() {
   auto& X = Input(INPUT);
   auto& filter = Input(FILTER);
-  auto& bias = Input(BIAS);
   auto* Y = Output(0);
 
   // Figure out the output shape
@@ -152,8 +156,6 @@ bool CudnnConvOp<T>::RunOnDevice() {
     default:
       LOG(FATAL) << "Unknown storage order: " << order_;
   }
-  DCHECK_EQ(bias.ndim(), 1);
-  DCHECK_EQ(bias.dim32(0), M);
 
   // Set up the cudnn algorithms & workspace if necessary
   bool input_changed = (X.dims() != cudnn_input_dims_);
@@ -176,9 +178,11 @@ bool CudnnConvOp<T>::RunOnDevice() {
         C,
         kernel_h_,
         kernel_w_));
-    CUDNN_CHECK(cudnnSetTensor4dDescriptor(
-        bias_desc_, GetCudnnTensorFormat(order_), cudnnTypeWrapper<T>::type,
-        1, M, 1, 1));
+    if (bias_) {
+      CUDNN_CHECK(cudnnSetTensor4dDescriptor(
+          bias_desc_, GetCudnnTensorFormat(order_), cudnnTypeWrapper<T>::type,
+          1, M, 1, 1));
+    }
   }
   // Set the output
   CUDNN_CHECK(cudnnSetTensor4dDescriptor(
@@ -267,14 +271,21 @@ bool CudnnConvOp<T>::RunOnDevice() {
             Y->template mutable_data<T>()));
       });
   // Bias
-  CUDNN_CHECK(cudnnAddTensor(
-      cudnn_wrapper_.inline_cudnn_handle(),
-      cudnnTypeWrapper<T>::kOne(),
-      bias_desc_,
-      bias.template data<T>(),
-      cudnnTypeWrapper<T>::kOne(),
-      top_desc_,
-      Y->template mutable_data<T>()));
+  if (bias_) {
+    auto& bias = Input(BIAS);
+
+    DCHECK_EQ(bias.ndim(), 1);
+    DCHECK_EQ(bias.dim32(0), M);
+
+    CUDNN_CHECK(cudnnAddTensor(
+        cudnn_wrapper_.inline_cudnn_handle(),
+        cudnnTypeWrapper<T>::kOne(),
+        bias_desc_,
+        bias.template data<T>(),
+        cudnnTypeWrapper<T>::kOne(),
+        top_desc_,
+        Y->template mutable_data<T>()));
+  }
   // Done.
   return true;
 }
@@ -287,7 +298,6 @@ bool CudnnConvGradientOp<T>::RunOnDevice() {
   auto& filter = Input(FILTER);
   auto& dY = Input(OUTPUT_GRAD);
   auto* dfilter = Output(FILTER_GRAD);
-  auto* dbias = Output(BIAS_GRAD);
 
   DCHECK_EQ(X.ndim(), 4);
   DCHECK_EQ(filter.ndim(), 4);
@@ -313,7 +323,6 @@ bool CudnnConvGradientOp<T>::RunOnDevice() {
   }
   ConvPoolOpBase<CUDAContext>::ComputePads(H, W);
   dfilter->ResizeLike(filter);
-  dbias->Resize(TIndex(M));
 
   // Set up the cudnn algorithms & workspace if necessary
   bool input_changed = (X.dims() != cudnn_input_dims_);
@@ -336,9 +345,11 @@ bool CudnnConvGradientOp<T>::RunOnDevice() {
         C,
         kernel_h_,
         kernel_w_));
-    CUDNN_CHECK(cudnnSetTensor4dDescriptor(
-        bias_desc_, GetCudnnTensorFormat(order_), cudnnTypeWrapper<T>::type,
-        1, M, 1, 1));
+    if (bias_) {
+      CUDNN_CHECK(cudnnSetTensor4dDescriptor(
+          bias_desc_, GetCudnnTensorFormat(order_), cudnnTypeWrapper<T>::type,
+          1, M, 1, 1));
+    }
   }
   // Set the output
   CUDNN_CHECK(cudnnSetTensor4dDescriptor(
@@ -404,7 +415,7 @@ bool CudnnConvGradientOp<T>::RunOnDevice() {
           return filter_perf_stat[0].algo;
         });
 
-    if (OutputSize() == 3) {
+    if (OutputSize() == 3 || (!bias_ && (OutputSize() == 2))) {
       bwd_data_algo_ =
           data_algo_cache_.getAlgorithm(X.dims(), filter.dims(), [&]() {
             VLOG(1) << "CUDNN Convolution bwd: doing data exhaustive search.";
@@ -416,7 +427,7 @@ bool CudnnConvGradientOp<T>::RunOnDevice() {
                 data_perf_stat;
             cudnn_wrapper_.with_cudnn_state(
                 cudnn_state_, [&](CuDNNState* state) {
-                  auto* dX = Output(INPUT_GRAD);
+                  auto* dX = Output(bias_ ? INPUT_GRAD : BIAS_OR_INPUT_GRAD);
                   dX->ResizeLike(X);
                   CUDNN_CHECK(cudnnFindConvolutionBackwardDataAlgorithmEx(
                       state->cudnn_handle(),
@@ -470,10 +481,14 @@ bool CudnnConvGradientOp<T>::RunOnDevice() {
   }
 
   // Now, actually run the computation.
-  CUDNN_CHECK(cudnnConvolutionBackwardBias(
-      cudnn_wrapper_.inline_cudnn_handle(), cudnnTypeWrapper<T>::kOne(), top_desc_,
-      dY.template data<T>(), cudnnTypeWrapper<T>::kZero(), bias_desc_,
-      dbias->template mutable_data<T>()));
+  if (bias_) {
+    auto* dbias = Output(BIAS_OR_INPUT_GRAD);
+    dbias->Resize(TIndex(M));
+    CUDNN_CHECK(cudnnConvolutionBackwardBias(
+        cudnn_wrapper_.inline_cudnn_handle(), cudnnTypeWrapper<T>::kOne(), top_desc_,
+        dY.template data<T>(), cudnnTypeWrapper<T>::kZero(), bias_desc_,
+        dbias->template mutable_data<T>()));
+  }
 
   cudnn_wrapper_.with_cudnn_state(cudnn_state_, [&](CuDNNState* state) {
     CUDNN_CHECK(cudnnConvolutionBackwardFilter(
@@ -490,9 +505,9 @@ bool CudnnConvGradientOp<T>::RunOnDevice() {
         cudnnTypeWrapper<T>::kZero(),
         filter_desc_,
         dfilter->template mutable_data<T>()));
-    if (OutputSize() == 3) {
+    if (OutputSize() == 3 || (!bias_ && (OutputSize() == 2))) {
       // Compute the gradient w.r.t. the input.
-      auto* dX = Output(INPUT_GRAD);
+      auto* dX = Output(bias_ ? INPUT_GRAD : BIAS_OR_INPUT_GRAD);
       dX->ResizeLike(X);
       CUDNN_CHECK(cudnnConvolutionBackwardData(
           state->cudnn_handle(),
@@ -274,7 +274,7 @@ bool ConvGradientOp<T, Context>::RunOnDeviceWithOrderNCHW() {
   auto& filter = Input(FILTER);
   auto& dY = Input(OUTPUT_GRAD);
   auto* dfilter = Output(FILTER_GRAD);
-  auto* dbias = Output(BIAS_GRAD);
+  auto* dbias = Output(BIAS_OR_INPUT_GRAD);
   const int N = X.dim32(0), C = X.dim32(1), H = X.dim32(2), W = X.dim32(3);
   ConvPoolOpBase<Context>::ComputePads(H, W);
   CAFFE_ENFORCE(4 == filter.ndim());
@@ -407,7 +407,7 @@ bool ConvGradientOp<T, Context>::RunOnDeviceWithOrderNHWC() {
   auto& filter = Input(FILTER);
   auto& dY = Input(OUTPUT_GRAD);
   auto* dfilter = Output(FILTER_GRAD);
-  auto* dbias = Output(BIAS_GRAD);
+  auto* dbias = Output(BIAS_OR_INPUT_GRAD);
   const int N = X.dim32(0), H = X.dim32(1), W = X.dim32(2), C = X.dim32(3);
   ConvPoolOpBase<Context>::ComputePads(H, W);
   CAFFE_ENFORCE(4 == filter.ndim());
@@ -70,6 +70,7 @@ class CNNModelHelper(ModelHelperBase):
         """Convolution. We intentionally do not provide odd kernel/stride/pad
         settings in order to discourage the use of odd cases.
         """
+        use_bias = False if ("no_bias" in kwargs and kwargs["no_bias"]) else True
         weight_init = weight_init if weight_init else ('XavierFill', {})
         bias_init = bias_init if bias_init else ('ConstantFill', {})
         blob_out = blob_out or self.net.NextName()
@@ -84,27 +85,43 @@ class CNNModelHelper(ModelHelperBase):
                 shape=weight_shape,
                 **weight_init[1]
             )
-            bias = self.param_init_net.__getattr__(bias_init[0])(
-                [],
-                blob_out + '_b',
-                shape=[dim_out, ],
-                **bias_init[1]
-            )
+            if use_bias:
+                bias = self.param_init_net.__getattr__(bias_init[0])(
+                    [],
+                    blob_out + '_b',
+                    shape=[dim_out, ],
+                    **bias_init[1]
+                )
         else:
             weight = core.ScopedBlobReference(
                 blob_out + '_w', self.param_init_net)
-            bias = core.ScopedBlobReference(
-                blob_out + '_b', self.param_init_net)
-        self.params.extend([weight, bias])
+            if use_bias:
+                bias = core.ScopedBlobReference(
+                    blob_out + '_b', self.param_init_net)
+        if use_bias:
+            self.params.extend([weight, bias])
+        else:
+            self.params.extend([weight])
+
         self.weights.append(weight)
-        self.biases.append(bias)
+
+        if use_bias:
+            self.biases.append(bias)
+
         if self.use_cudnn:
             kwargs['engine'] = 'CUDNN'
             kwargs['exhaustive_search'] = self.cudnn_exhaustive_search
             if self.ws_nbytes_limit:
                 kwargs['ws_nbytes_limit'] = self.ws_nbytes_limit
+
+        inputs = []
+        if use_bias:
+            inputs = [blob_in, weight, bias]
+        else:
+            inputs = [blob_in, weight]
+
         return self.net.Conv(
-            [blob_in, weight, bias],
+            inputs,
             blob_out,
             kernel=kernel,
             order=self.order,