mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-07 12:20:24 +01:00
Add V2 versions of output window size computation functions for convolution.
These V2 versions take arbitrary dilation rates. In preparation for the support of native cudnn dilated convolution. PiperOrigin-RevId: 171048878
This commit is contained in:
parent
491584ff4d
commit
cf17ec96ed
|
|
@ -17,24 +17,31 @@ limitations under the License.
|
|||
|
||||
namespace tensorflow {
|
||||
|
||||
Status GetWindowedOutputSizeVerbose(int64 input_size, int64 filter_size,
|
||||
int64 stride, Padding padding_type,
|
||||
int64* output_size, int64* padding_before,
|
||||
Status GetWindowedOutputSizeVerboseV2(int64 input_size, int64 filter_size,
|
||||
int64 dilation_rate, int64 stride,
|
||||
Padding padding_type, int64* output_size,
|
||||
int64* padding_before,
|
||||
int64* padding_after) {
|
||||
if (stride <= 0) {
|
||||
return errors::InvalidArgument("Stride must be > 0, but got ", stride);
|
||||
}
|
||||
if (dilation_rate < 1) {
|
||||
return errors::InvalidArgument("Dilation rate must be >= 1, but got ",
|
||||
dilation_rate);
|
||||
}
|
||||
|
||||
// See also the parallel implementation in GetWindowedOutputSizeFromDims.
|
||||
// See also the parallel implementation in GetWindowedOutputSizeFromDimsV2.
|
||||
int64 effective_filter_size = (filter_size - 1) * dilation_rate + 1;
|
||||
switch (padding_type) {
|
||||
case Padding::VALID:
|
||||
*output_size = (input_size - filter_size + stride) / stride;
|
||||
*output_size = (input_size - effective_filter_size + stride) / stride;
|
||||
*padding_before = *padding_after = 0;
|
||||
break;
|
||||
case Padding::SAME:
|
||||
*output_size = (input_size + stride - 1) / stride;
|
||||
const int64 padding_needed =
|
||||
std::max(0LL, (*output_size - 1) * stride + filter_size - input_size);
|
||||
std::max(0LL, (*output_size - 1) * stride + effective_filter_size -
|
||||
input_size);
|
||||
// For odd values of total padding, add more padding at the 'right'
|
||||
// side of the given dimension.
|
||||
*padding_before = padding_needed / 2;
|
||||
|
|
@ -47,15 +54,35 @@ Status GetWindowedOutputSizeVerbose(int64 input_size, int64 filter_size,
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
Status GetWindowedOutputSizeVerbose(int64 input_size, int64 filter_size,
|
||||
int64 stride, Padding padding_type,
|
||||
int64* output_size, int64* padding_before,
|
||||
int64* padding_after) {
|
||||
return GetWindowedOutputSizeVerboseV2(input_size, filter_size,
|
||||
/*dilation_rate=*/1, stride,
|
||||
padding_type, output_size,
|
||||
padding_before, padding_after);
|
||||
}
|
||||
|
||||
Status GetWindowedOutputSize(int64 input_size, int64 filter_size, int64 stride,
|
||||
Padding padding_type, int64* output_size,
|
||||
int64* padding) {
|
||||
int64* padding_size) {
|
||||
int64 padding_after_unused;
|
||||
return GetWindowedOutputSizeVerbose(input_size, filter_size, stride,
|
||||
padding_type, output_size, padding,
|
||||
padding_type, output_size, padding_size,
|
||||
&padding_after_unused);
|
||||
}
|
||||
|
||||
Status GetWindowedOutputSizeV2(int64 input_size, int64 filter_size,
|
||||
int64 dilation_rate, int64 stride,
|
||||
Padding padding_type, int64* output_size,
|
||||
int64* padding_size) {
|
||||
int64 padding_after_unused;
|
||||
return GetWindowedOutputSizeVerboseV2(input_size, filter_size, dilation_rate,
|
||||
stride, padding_type, output_size,
|
||||
padding_size, &padding_after_unused);
|
||||
}
|
||||
|
||||
Status Get3dOutputSize(const std::array<int64, 3>& input,
|
||||
const std::array<int64, 3>& window,
|
||||
const std::array<int64, 3>& strides,
|
||||
|
|
@ -69,32 +96,75 @@ Status Get3dOutputSize(const std::array<int64, 3>& input,
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
Status Get3dOutputSizeV2(const std::array<int64, 3>& input,
|
||||
const std::array<int64, 3>& window,
|
||||
const std::array<int64, 3>& dilations,
|
||||
const std::array<int64, 3>& strides,
|
||||
Padding padding_type, std::array<int64, 3>* output_ptr,
|
||||
std::array<int64, 3>* padding_ptr) {
|
||||
for (size_t i = 0; i < input.size(); ++i) {
|
||||
TF_RETURN_IF_ERROR(GetWindowedOutputSizeV2(
|
||||
input[i], window[i], dilations[i], strides[i], padding_type,
|
||||
&(*output_ptr)[i], &(*padding_ptr)[i]));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
namespace shape_inference {
|
||||
|
||||
// The V2 version computes windowed output size with arbitrary dilation_rate,
|
||||
// while the original version only handles the cases where dilation_rates equal
|
||||
// to 1.
|
||||
Status GetWindowedOutputSizeFromDimsV2(
|
||||
shape_inference::InferenceContext* c,
|
||||
shape_inference::DimensionHandle input_size,
|
||||
shape_inference::DimensionOrConstant filter_size, int64 dilation_rate,
|
||||
int64 stride, Padding padding_type,
|
||||
shape_inference::DimensionHandle* output_size) {
|
||||
if (stride <= 0) {
|
||||
return errors::InvalidArgument("Stride must be > 0, but got ", stride);
|
||||
}
|
||||
|
||||
if (dilation_rate < 1) {
|
||||
return errors::InvalidArgument("Dilation rate must be >= 1, but got ",
|
||||
dilation_rate);
|
||||
}
|
||||
|
||||
// See also the parallel implementation in GetWindowedOutputSizeVerbose.
|
||||
switch (padding_type) {
|
||||
case Padding::VALID:
|
||||
if (dilation_rate > 1) {
|
||||
DimensionHandle window_size;
|
||||
TF_RETURN_IF_ERROR(
|
||||
c->Subtract(c->MakeDim(filter_size), 1, &window_size));
|
||||
TF_RETURN_IF_ERROR(
|
||||
c->Multiply(window_size, dilation_rate, &window_size));
|
||||
TF_RETURN_IF_ERROR(c->Add(window_size, 1, &window_size));
|
||||
TF_RETURN_IF_ERROR(c->Subtract(input_size, window_size, output_size));
|
||||
} else {
|
||||
TF_RETURN_IF_ERROR(c->Subtract(input_size, filter_size, output_size));
|
||||
}
|
||||
TF_RETURN_IF_ERROR(c->Add(*output_size, stride, output_size));
|
||||
TF_RETURN_IF_ERROR(c->Divide(*output_size, stride,
|
||||
/*evenly_divisible=*/false, output_size));
|
||||
break;
|
||||
case Padding::SAME:
|
||||
TF_RETURN_IF_ERROR(c->Add(input_size, stride - 1, output_size));
|
||||
TF_RETURN_IF_ERROR(c->Divide(*output_size, stride,
|
||||
/*evenly_divisible=*/false, output_size));
|
||||
break;
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status GetWindowedOutputSizeFromDims(
|
||||
shape_inference::InferenceContext* c,
|
||||
shape_inference::DimensionHandle input_size,
|
||||
shape_inference::DimensionOrConstant filter_size, int64 stride,
|
||||
Padding padding_type, shape_inference::DimensionHandle* output_size) {
|
||||
if (stride <= 0) {
|
||||
return errors::InvalidArgument("Stride must be > 0, but got ", stride);
|
||||
}
|
||||
|
||||
// See also the parallel implementation in GetWindowedOutputSizeVerbose.
|
||||
switch (padding_type) {
|
||||
case Padding::VALID:
|
||||
TF_RETURN_IF_ERROR(c->Subtract(input_size, filter_size, output_size));
|
||||
TF_RETURN_IF_ERROR(c->Add(*output_size, stride, output_size));
|
||||
TF_RETURN_IF_ERROR(c->Divide(*output_size, stride,
|
||||
false /* evenly_divisible */, output_size));
|
||||
break;
|
||||
case Padding::SAME:
|
||||
TF_RETURN_IF_ERROR(c->Add(input_size, stride - 1, output_size));
|
||||
TF_RETURN_IF_ERROR(c->Divide(*output_size, stride,
|
||||
false /* evenly_divisible */, output_size));
|
||||
break;
|
||||
}
|
||||
return Status::OK();
|
||||
return GetWindowedOutputSizeFromDimsV2(c, input_size, filter_size,
|
||||
/*dilation_rate=*/1, stride,
|
||||
padding_type, output_size);
|
||||
}
|
||||
|
||||
Status UnchangedShape(shape_inference::InferenceContext* c) {
|
||||
|
|
|
|||
|
|
@ -75,6 +75,32 @@ Status GetWindowedOutputSize(int64 input_size, int64 filter_size, int64 stride,
|
|||
Padding padding_type, int64* output_size,
|
||||
int64* padding_size);
|
||||
|
||||
// The V2 version computes the same outputs with arbitrary dilation_rate.
|
||||
// The output dimensions are computed as follows:
|
||||
// - When adding dilation_rate (D), we compute an effective filter size (K'):
|
||||
// K' = (K - 1) * D + 1
|
||||
// - When Padding = SAME: the output size is (H'), where
|
||||
// H' = ceil(float(H) / float(S))
|
||||
// where ceil is the ceiling function. The number of padded cells
|
||||
// is computed as:
|
||||
// Pc = ((H' - 1) * S + K' - H) / 2
|
||||
// When the stride is 1, the expression simplifies to
|
||||
// H' = H, Pc = (K'-1)/2.
|
||||
// This is where SAME comes from - the output has the same size as the input
|
||||
// has.
|
||||
//
|
||||
// - When Padding = VALID: the output size is computed as
|
||||
// H' = ceil(float(H - K' + 1) / float(S))
|
||||
// and the number of padded cells is always zero.
|
||||
// When the stride is 1, the expression simplifies to
|
||||
// H' = H-K'+1.
|
||||
//
|
||||
// TODO(b/67112639): Merge V2 versions and the original versions eventually.
|
||||
Status GetWindowedOutputSizeV2(int64 input_size, int64 filter_size,
|
||||
int64 dilation_rate, int64 stride,
|
||||
Padding padding_type, int64* output_size,
|
||||
int64* padding_size);
|
||||
|
||||
// Returns the same output dimensions as in GetWindowedOutputSize, but returns
|
||||
// verbose padding dimensions (before/after). Any excess padding
|
||||
// (caused by an odd padding size value) is added to the 'padding_after'
|
||||
|
|
@ -84,6 +110,14 @@ Status GetWindowedOutputSizeVerbose(int64 input_size, int64 filter_size,
|
|||
int64* output_size, int64* padding_before,
|
||||
int64* padding_after);
|
||||
|
||||
// The V2 version computes the same outputs with arbitrary dilation_rate. For
|
||||
// detailed equations, refer to the comments for GetWindowedOutputSizeV2().
|
||||
Status GetWindowedOutputSizeVerboseV2(int64 input_size, int64 filter_size,
|
||||
int64 dilation_rate, int64 stride,
|
||||
Padding padding_type, int64* output_size,
|
||||
int64* padding_before,
|
||||
int64* padding_after);
|
||||
|
||||
// Given an input tensor, kernel, stride and padding type, populates the 3D size
|
||||
// of the output tensor and padding to be applied to the input tensor at the
|
||||
// lower end of every dimension. Use for 3D convolutions, where the input data
|
||||
|
|
@ -92,8 +126,17 @@ Status GetWindowedOutputSizeVerbose(int64 input_size, int64 filter_size,
|
|||
Status Get3dOutputSize(const std::array<int64, 3>& input,
|
||||
const std::array<int64, 3>& window,
|
||||
const std::array<int64, 3>& strides,
|
||||
Padding padding_type, std::array<int64, 3>* output,
|
||||
std::array<int64, 3>* padding);
|
||||
Padding padding_type, std::array<int64, 3>* output_ptr,
|
||||
std::array<int64, 3>* padding_ptr);
|
||||
|
||||
// The V2 version computes the same outputs with arbitrary dilation_rate. For
|
||||
// detailed equations, refer to the comments for GetWindowedOutputSizeV2().
|
||||
Status Get3dOutputSizeV2(const std::array<int64, 3>& input,
|
||||
const std::array<int64, 3>& window,
|
||||
const std::array<int64, 3>& dilations,
|
||||
const std::array<int64, 3>& strides,
|
||||
Padding padding_type, std::array<int64, 3>* output_ptr,
|
||||
std::array<int64, 3>* padding_ptr);
|
||||
|
||||
namespace shape_inference {
|
||||
|
||||
|
|
@ -104,6 +147,15 @@ Status GetWindowedOutputSizeFromDims(InferenceContext* c,
|
|||
int64 stride, Padding padding_type,
|
||||
DimensionHandle* output_size);
|
||||
|
||||
// The V2 version computes the same outputs with arbitrary dilation_rate. For
|
||||
// detailed equations, refer to the comments for GetWindowedOutputSizeV2().
|
||||
Status GetWindowedOutputSizeFromDimsV2(InferenceContext* c,
|
||||
DimensionHandle input_size,
|
||||
DimensionOrConstant filter_size,
|
||||
int64 dilation_rate, int64 stride,
|
||||
Padding padding_type,
|
||||
DimensionHandle* output_size);
|
||||
|
||||
// Transfers shape of input(0) to output(0).
|
||||
Status UnchangedShape(shape_inference::InferenceContext* c);
|
||||
|
||||
|
|
|
|||
|
|
@ -23,6 +23,7 @@ limitations under the License.
|
|||
#include <algorithm>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/core/framework/common_shape_fns.h"
|
||||
#include "tensorflow/core/framework/numeric_op.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/register_types.h"
|
||||
|
|
@ -40,46 +41,64 @@ limitations under the License.
|
|||
|
||||
namespace tensorflow {
|
||||
|
||||
// The V2 version computes windowed output size with arbitrary dilation_rate,
|
||||
// while the original version only handles the cases where dilation_rates equal
|
||||
// to 1.
|
||||
Status ConvBackpropExtractAndVerifyDimensionV2(
|
||||
StringPiece label, const TensorShape& input_shape,
|
||||
const TensorShape& filter_shape, const TensorShape& output_shape,
|
||||
const gtl::ArraySlice<int32>& dilations, const std::vector<int32>& strides,
|
||||
Padding padding, int spatial_dim, int filter_spatial_dim,
|
||||
ConvBackpropSpatialDimension* dim) {
|
||||
dim->input_size = input_shape.dim_size(spatial_dim);
|
||||
dim->filter_size = filter_shape.dim_size(filter_spatial_dim);
|
||||
dim->output_size = output_shape.dim_size(spatial_dim);
|
||||
dim->stride = strides[spatial_dim];
|
||||
dim->dilation = dilations[spatial_dim];
|
||||
int64 out_size = 0, pad_size = 0;
|
||||
TF_RETURN_IF_ERROR(GetWindowedOutputSizeV2(dim->input_size, dim->filter_size,
|
||||
dim->dilation, dim->stride,
|
||||
padding, &out_size, &pad_size));
|
||||
if (dim->output_size != out_size) {
|
||||
return errors::InvalidArgument(
|
||||
label, ": Size of out_backprop doesn't match computed: ", "actual = ",
|
||||
dim->output_size, ", computed = ", out_size,
|
||||
"spatial_dim: ", spatial_dim, " input: ", dim->input_size,
|
||||
" filter: ", dim->filter_size, " output: ", dim->output_size,
|
||||
" stride: ", dim->stride, " dilation: ", dim->dilation);
|
||||
}
|
||||
|
||||
int64 effective_filter_size = (dim->filter_size - 1) * dim->dilation + 1;
|
||||
dim->expanded_output_size = (dim->output_size - 1) * dim->stride + 1;
|
||||
const auto padded_out_size = dim->input_size + effective_filter_size - 1;
|
||||
dim->pad_before = effective_filter_size - 1 - pad_size;
|
||||
dim->pad_after =
|
||||
padded_out_size - dim->expanded_output_size - dim->pad_before;
|
||||
VLOG(2) << label << ": expanded_out = " << dim->expanded_output_size
|
||||
<< ", effective_filter_size = " << effective_filter_size
|
||||
<< ", padded_out = " << padded_out_size
|
||||
<< ", pad_before = " << dim->pad_before
|
||||
<< ", pad_after = " << dim->pad_after
|
||||
<< ", dilation = " << dim->dilation << ", strides = " << dim->stride;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status ConvBackpropExtractAndVerifyDimension(
|
||||
StringPiece label, const TensorShape& input_shape,
|
||||
const TensorShape& filter_shape, const TensorShape& output_shape,
|
||||
const std::vector<int32>& strides, Padding padding, int spatial_dim,
|
||||
int filter_spatial_dim, ConvBackpropSpatialDimension* dim) {
|
||||
dim->input_size = input_shape.dim_size(spatial_dim);
|
||||
dim->filter_size = filter_shape.dim_size(filter_spatial_dim);
|
||||
dim->output_size = output_shape.dim_size(spatial_dim);
|
||||
dim->stride = strides[spatial_dim];
|
||||
int64 out_size = 0, pad_size = 0;
|
||||
TF_RETURN_IF_ERROR(GetWindowedOutputSize(dim->input_size, dim->filter_size,
|
||||
dim->stride, padding, &out_size,
|
||||
&pad_size));
|
||||
if (dim->output_size != out_size) {
|
||||
return errors::InvalidArgument(
|
||||
label, ": Size of out_backprop doesn't match computed: ", "actual = ",
|
||||
dim->output_size, ", computed = ", out_size);
|
||||
}
|
||||
|
||||
dim->expanded_output_size = (dim->output_size - 1) * dim->stride + 1;
|
||||
const auto padded_out_size = dim->input_size + dim->filter_size - 1;
|
||||
dim->pad_before = dim->filter_size - 1 - pad_size;
|
||||
dim->pad_after =
|
||||
padded_out_size - dim->expanded_output_size - dim->pad_before;
|
||||
VLOG(2) << label << ": expanded_out = " << dim->expanded_output_size
|
||||
<< ", filter = " << dim->filter_size
|
||||
<< ", padded_out = " << padded_out_size
|
||||
<< ", pad_before = " << dim->pad_before
|
||||
<< ", pad_after = " << dim->pad_after
|
||||
<< ", strides = " << dim->stride;
|
||||
return Status::OK();
|
||||
static constexpr std::array<int32, 5> one_dilations = {{1, 1, 1, 1, 1}};
|
||||
return ConvBackpropExtractAndVerifyDimensionV2(
|
||||
label, input_shape, filter_shape, output_shape, one_dilations, strides,
|
||||
padding, spatial_dim, filter_spatial_dim, dim);
|
||||
}
|
||||
|
||||
Status ConvBackpropComputeDimensions(StringPiece label, int num_spatial_dims,
|
||||
const TensorShape& input_shape,
|
||||
const TensorShape& filter_shape,
|
||||
const TensorShape& out_backprop_shape,
|
||||
const std::vector<int32>& strides,
|
||||
Padding padding, TensorFormat data_format,
|
||||
ConvBackpropDimensions* dims) {
|
||||
Status ConvBackpropComputeDimensionsV2(
|
||||
StringPiece label, int num_spatial_dims, const TensorShape& input_shape,
|
||||
const TensorShape& filter_shape, const TensorShape& out_backprop_shape,
|
||||
const gtl::ArraySlice<int32>& dilations, const std::vector<int32>& strides,
|
||||
Padding padding, TensorFormat data_format, ConvBackpropDimensions* dims) {
|
||||
// The + 2 in the following line is for the batch and feature dimensions.
|
||||
const int num_dims = num_spatial_dims + 2;
|
||||
if (input_shape.dims() != num_dims) {
|
||||
|
|
@ -98,7 +117,10 @@ Status ConvBackpropComputeDimensions(StringPiece label, int num_spatial_dims,
|
|||
dims->batch_size = input_shape.dim_size(batch_dim);
|
||||
if (dims->batch_size != out_backprop_shape.dim_size(batch_dim)) {
|
||||
return errors::InvalidArgument(
|
||||
label, ": input and out_backprop must have the same batch size");
|
||||
label, ": input and out_backprop must have the same batch size",
|
||||
"input batch: ", dims->batch_size,
|
||||
"outbackprop batch: ", out_backprop_shape.dim_size(batch_dim),
|
||||
" batch_dim: ", batch_dim);
|
||||
}
|
||||
|
||||
int feature_dim = GetTensorFeatureDimIndex(num_dims, data_format);
|
||||
|
|
@ -118,11 +140,24 @@ Status ConvBackpropComputeDimensions(StringPiece label, int num_spatial_dims,
|
|||
dims->spatial_dims.resize(num_spatial_dims);
|
||||
for (int i = 0; i < num_spatial_dims; ++i) {
|
||||
int image_dim = GetTensorSpatialDimIndex(num_dims, data_format, i);
|
||||
TF_RETURN_IF_ERROR(ConvBackpropExtractAndVerifyDimension(
|
||||
label, input_shape, filter_shape, out_backprop_shape, strides, padding,
|
||||
image_dim, i, &dims->spatial_dims[i]));
|
||||
TF_RETURN_IF_ERROR(ConvBackpropExtractAndVerifyDimensionV2(
|
||||
label, input_shape, filter_shape, out_backprop_shape, dilations,
|
||||
strides, padding, image_dim, i, &dims->spatial_dims[i]));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status ConvBackpropComputeDimensions(StringPiece label, int num_spatial_dims,
|
||||
const TensorShape& input_shape,
|
||||
const TensorShape& filter_shape,
|
||||
const TensorShape& out_backprop_shape,
|
||||
const std::vector<int32>& strides,
|
||||
Padding padding, TensorFormat data_format,
|
||||
ConvBackpropDimensions* dims) {
|
||||
static constexpr std::array<int32, 5> one_dilations = {{1, 1, 1, 1, 1}};
|
||||
return ConvBackpropComputeDimensionsV2(
|
||||
label, num_spatial_dims, input_shape, filter_shape, out_backprop_shape,
|
||||
one_dilations, strides, padding, data_format, dims);
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
||||
|
|
|
|||
|
|
@ -212,6 +212,7 @@ struct ConvBackpropSpatialDimension {
|
|||
int64 filter_size;
|
||||
int64 output_size;
|
||||
int64 stride;
|
||||
int64 dilation;
|
||||
int64 expanded_output_size;
|
||||
|
||||
// Number of padding elements to be added before/after this dimension of
|
||||
|
|
@ -242,6 +243,13 @@ Status ConvBackpropComputeDimensions(StringPiece label, int num_spatial_dims,
|
|||
Padding padding, TensorFormat data_format,
|
||||
ConvBackpropDimensions* dims);
|
||||
|
||||
// The V2 version computes the same outputs with arbitrary dilation rate.
|
||||
// TODO(b/67112639): Merge V2 versions and the original versions eventually.
|
||||
Status ConvBackpropComputeDimensionsV2(
|
||||
StringPiece label, int num_spatial_dims, const TensorShape& input_shape,
|
||||
const TensorShape& filter_shape, const TensorShape& out_backprop_shape,
|
||||
const std::vector<int32>& dilations, const std::vector<int32>& strides,
|
||||
Padding padding, TensorFormat data_format, ConvBackpropDimensions* dims);
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CORE_KERNELS_CONV_GRAD_OPS_H_
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user