In the CUDA path of depthwise_conv2d, add a fast NCHW backward input convolution for images smaller than 16x16.
PiperOrigin-RevId: 158061669
parent bee26215c9
commit 827874c307
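The fast path added here only applies to shapes the small-image tile kernels can handle: stride 1, depth multiplier 1, 'SAME' padding, and images no larger than 16x16, with a filter small enough to share a tile with the image. As a reading aid, the sketch below condenses those conditions (they appear in the launcher code further down) into one standalone predicate; DepthwiseArgsSketch and CanUseSmallImagePath are illustrative names only, while the gate actually referenced by the commit is CanLaunchDepthwiseConv2dGPUSmall.

// Sketch only: a condensed, stand-alone version of the eligibility test that
// gates the fast small-image path. Field names mirror the DepthwiseArgs struct
// in depthwise_conv_op_gpu.cu.cc; the struct and function here are
// illustrative, not identifiers from the commit.
struct DepthwiseArgsSketch {
  int batch, in_rows, in_cols, in_depth;
  int filter_rows, filter_cols, depth_multiplier, stride;
  int pad_rows, pad_cols, out_rows, out_cols, out_depth;
};

bool CanUseSmallImagePath(const DepthwiseArgsSketch& args) {
  // One thread block covers a whole image, two rows per thread.
  const int block_rows = (args.in_rows + 1) / 2;
  return args.depth_multiplier == 1 && args.stride == 1 &&
         // The image must fit into a single shared-memory tile.
         args.in_rows <= 16 && args.in_cols <= 16 &&
         // 'SAME' padding: output matches input, pads smaller than the filter.
         args.in_rows == args.out_rows && args.in_cols == args.out_cols &&
         args.pad_rows >= 0 && args.pad_rows < args.filter_rows &&
         args.pad_cols >= 0 && args.pad_cols < args.filter_cols &&
         // The filter tile must fit next to the image tile in shared memory.
         args.filter_rows * args.filter_cols <= args.in_cols * block_rows;
}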
@@ -129,7 +129,7 @@ __global__ void __launch_bounds__(1024, 2)
   }
 }
 
-// CUDA kernel to compute the depthwise convolution forward pass in NCHW format,
+// CUDA kernel to compute the depthwise convolution forward pass in NHWC format,
 // tailored for small images up to 16x16. Stride and depth multiplier must be 1.
 // Padding must be 'SAME', which allows to reuse the index computation. Only
 // use this kernel if CanLaunchDepthwiseConv2dGPUSmall(args) returns true.
@@ -566,7 +566,7 @@ void LaunchDepthwiseConv2dGPUSmall(const GpuDevice& d, const DepthwiseArgs args,
         <<<config.block_count, block_dim, shared_memory_size, d.stream()>>>(
             args, input, filter, output);
   } else {
-    assert(false);
+    assert(false && "Incorrect data format");
   }
 }
 
@@ -625,7 +625,7 @@ void LaunchDepthwiseConv2dGPU(const GpuDevice& d, const DepthwiseArgs args,
            config.thread_per_block, 0, d.stream()>>>(args, input, filter,
                                                      output, num_outputs);
   } else {
-    assert(false);
+    assert(false && "Incorrect data format");
   }
 }
 
@@ -714,9 +714,10 @@ __global__ void __launch_bounds__(640, 2)
 }
 
 // CUDA kernel to compute the depthwise convolution backward w.r.t. input in
-// NCHW format, tailored for small images up to 16x16. Stride and depth
+// NHWC format, tailored for small images up to 16x16. Stride and depth
 // multiplier must be 1. Padding must be 'SAME', which allows to reuse the index
-// computation.
+// computation. Only use this kernel if CanLaunchDepthwiseConv2dGPUSmall(args)
+// returns true.
 // Implementation is the same as the forward pass, except that the filter is
 // rotate by 180°, see filter_read_offset and filter_ptr.
 // Tiles of the input and filter tensors are loaded into shared memory before
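The comment block above says the backward-input pass reuses the forward index computation with the filter rotated by 180°, which is what 'SAME' padding with stride 1 makes possible. The snippet below is an illustration only, not code from the commit: it checks that identity numerically for a plain single-channel, stride-1 'SAME' convolution with an odd filter.

#include <cassert>
#include <cmath>
#include <vector>

// Cross-correlation with zero padding p = k / 2 (k odd), stride 1 ('SAME').
std::vector<float> ConvSame(const std::vector<float>& in, int n,
                            const std::vector<float>& w, int k) {
  const int p = k / 2;
  std::vector<float> out(n * n, 0.f);
  for (int y = 0; y < n; ++y)
    for (int x = 0; x < n; ++x)
      for (int r = 0; r < k; ++r)
        for (int c = 0; c < k; ++c) {
          const int iy = y + r - p, ix = x + c - p;
          if (iy >= 0 && iy < n && ix >= 0 && ix < n)
            out[y * n + x] += w[r * k + c] * in[iy * n + ix];
        }
  return out;
}

int main() {
  const int n = 5, k = 3;
  std::vector<float> grad_out(n * n), filter(k * k), rotated(k * k);
  for (int i = 0; i < n * n; ++i) grad_out[i] = 0.1f * i - 1.f;
  for (int i = 0; i < k * k; ++i) filter[i] = 0.5f * i - 2.f;
  // Rotating by 180 degrees reverses the row-major element order.
  for (int i = 0; i < k * k; ++i) rotated[i] = filter[k * k - 1 - i];

  // Gradient w.r.t. the input, taken straight from the forward definition:
  // every output element scatters its gradient to the inputs it read.
  std::vector<float> grad_in(n * n, 0.f);
  for (int y = 0; y < n; ++y)
    for (int x = 0; x < n; ++x)
      for (int r = 0; r < k; ++r)
        for (int c = 0; c < k; ++c) {
          const int iy = y + r - k / 2, ix = x + c - k / 2;
          if (iy >= 0 && iy < n && ix >= 0 && ix < n)
            grad_in[iy * n + ix] += filter[r * k + c] * grad_out[y * n + x];
        }

  // Same result: a forward 'SAME' pass over the gradient with the rotated filter.
  const std::vector<float> via_rotated = ConvSame(grad_out, n, rotated, k);
  for (int i = 0; i < n * n; ++i)
    assert(std::fabs(grad_in[i] - via_rotated[i]) < 1e-4f);
  return 0;
}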
@@ -727,6 +728,7 @@ template <typename T, int kKnownFilterWidth, int kKnownFilterHeight,
 __global__
 __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropInputGPUKernelNHWCSmall(
     const DepthwiseArgs args, const T* input, const T* filter, T* output) {
+  assert(CanLaunchDepthwiseConv2dGPUSmall(args));
   // Holds block plus halo and filter data for blockDim.x depths.
   extern __shared__ __align__(sizeof(T)) unsigned char shared_memory[];
   T* const shared_data = reinterpret_cast<T*>(shared_memory);
@@ -922,56 +924,209 @@ __global__ void __launch_bounds__(640, 2)
   }
 }
 
-template <typename T, int kKnownFilterWidth, int kKnownFilterHeight>
-bool TryLaunchDepthwiseConv2dBackpropInputGPUSmall(
-    const GpuDevice& d, const DepthwiseArgs args, const T* out_backprop,
-    const T* filter, T* in_backprop, TensorFormat data_format) {
-  if (data_format != FORMAT_NHWC || args.depth_multiplier != 1 ||
-      args.stride != 1 || args.in_rows > 16 || args.in_cols > 16 ||
-      args.in_rows != args.out_rows || args.in_cols != args.out_cols ||
-      args.pad_rows < 0 || args.pad_rows >= args.filter_rows ||
-      args.pad_cols < 0 || args.pad_cols >= args.filter_cols) {
-    return false;
-  }
+// CUDA kernel to compute the depthwise convolution backward w.r.t. input in
+// NHWC format, tailored for small images up to 16x16. Stride and depth
+// multiplier must be 1. Padding must be 'SAME', which allows to reuse the index
+// computation. Only use this kernel if CanLaunchDepthwiseConv2dGPUSmall(args)
+// returns true.
+// Implementation is the same as the forward pass, except that the filter is
+// rotate by 180°, see filter_read_offset and filter_ptr.
+// Tiles of the input and filter tensors are loaded into shared memory before
+// performing the convolution. Each thread handles two elements per iteration,
+// one each in the lower and upper half of a tile.
+template <typename T, int kKnownFilterWidth, int kKnownFilterHeight,
+          bool kKnownEvenRows>
+__global__
+__launch_bounds__(1024, 2) void DepthwiseConv2dBackpropInputGPUKernelNCHWSmall(
+    const DepthwiseArgs args, const T* input, const T* filter, T* output) {
+  assert(CanLaunchDepthwiseConv2dGPUSmall(args));
+  // Holds block plus halo and filter data for blockDim.z depths.
+  extern __shared__ __align__(sizeof(T)) unsigned char shared_memory[];
+  T* const shared_data = reinterpret_cast<T*>(shared_memory);
+
+  const int batches = args.batch;
+  const int in_rows = args.in_rows;
+  const int in_cols = args.in_cols;
+  const int in_depth = args.in_depth;
+  const int filter_rows =
+      kKnownFilterHeight < 0 ? args.filter_rows : kKnownFilterHeight;
+  const int filter_cols =
+      kKnownFilterWidth < 0 ? args.filter_cols : kKnownFilterWidth;
+  const int pad_rows = args.pad_rows;
+  const int pad_cols = args.pad_cols;
+
+  // Fixed blockDim.z, tailored for maximum grid size for images of size 16x16.
+  const int block_rows = blockDim.y;
+  const int block_slices = 8;
+
+  // These values are the same for all threads and could
+  // be precomputed on the CPU.
+  const int block_pixels = in_cols * block_rows;
+  const int block_size = block_pixels * block_slices;
+  const int in_pixels = in_cols * in_rows;
+  const int in_increment = in_cols - 1;
+  const int filter_pixels = filter_rows * filter_cols;
+  const int tile_cols = in_cols + filter_cols - 1;
+  const int even_rows = kKnownEvenRows || (1 & ~in_rows);
+  const int tile_rows = in_rows + filter_rows - even_rows;
+  const int tile_pixels = tile_cols * tile_rows;
+  const int tile_size = tile_pixels * block_slices;
+  const int tile_offset = block_rows * tile_cols;
+  const int pad_offset = pad_rows * tile_cols + pad_cols;
+  const int in_slices = in_depth * batches;
+  const int in_blocks = (in_slices + block_slices - 1) / block_slices;
+
+  const int thread_col = threadIdx.x;
+  const int thread_row = threadIdx.y;
+  const int thread_depth = threadIdx.z;
+
+  // Position in block.
+  const int thread_pix = thread_row * in_cols + thread_col;
+  const int thread_idx = thread_depth * block_pixels + thread_pix;
+
+  // Initialize tile, in particular the padding.
+  for (int i = thread_idx; i < tile_size; i += block_size) {
+    shared_data[i] = T(0);
+  }
+  __syncthreads();
+
+  // Position in tensors.
+  const int tensor_idx = thread_depth * in_pixels + thread_pix;
+
+  // Position in (padded) shared memory.
+  const int data_pix = thread_row * tile_cols + thread_col;
+  const int data_idx = thread_depth * tile_pixels + data_pix;
+
+  // Position in shared memory, offset by pad_rows / pad_cols.
+  const int tile_idx = data_idx + pad_offset;
+
+  // Filter is always in HWCK format, irrespective of the input/output format.
+  const int filter_pix = thread_idx / block_slices;
+  const int filter_depth = thread_idx % block_slices;
+  const int filter_idx = filter_pix * in_depth;
+
+  const int max_slice = in_slices - thread_depth;
+  const int filter_write_offset =
+      filter_pix < filter_pixels ? tile_size + thread_idx : 0;
+  const int filter_read_offset =
+      tile_size + filter_pixels * block_slices + thread_depth;
+  const bool skip_second =
+      !kKnownEvenRows && thread_row + (in_rows & 1) == block_rows;
+
+  for (int b = blockIdx.x; b < in_blocks; b += gridDim.x) {
+    const int slice = b * block_slices;
+
+    const int inout_offset = slice * in_pixels + tensor_idx;
+    const bool slice_in_range = slice < max_slice;
+
+    if (slice_in_range) {
+      const T* const in_ptr = inout_offset + input;
+      T* const tile_ptr = tile_idx + shared_data;
+      tile_ptr[0] = ldg(in_ptr);
+      if (!skip_second) {
+        tile_ptr[tile_offset] = ldg(block_pixels + in_ptr);
+      }
+    }
+
+    if (filter_write_offset != 0) {
+      const int filter_offset = filter_idx + (slice + filter_depth) % in_depth;
+      shared_data[filter_write_offset] = ldg(filter_offset + filter);
+    }
+
+    // Note: the condition to reach this is uniform across the entire block.
+    __syncthreads();
+
+    if (slice_in_range) {
+      T sum1 = 0;
+      T sum2 = 0;
+      int shared_offset = data_idx;
+      const T* filter_ptr = filter_read_offset + shared_data;
+      UNROLL for (int r = 0; r < filter_rows; ++r) {
+        UNROLL for (int c = 0; c < filter_cols; ++c) {
+          filter_ptr -= block_slices;
+          const T filter_value = *filter_ptr;
+          const T* const tile_ptr = shared_offset + shared_data;
+          sum1 += filter_value * tile_ptr[0];
+          sum2 += filter_value * tile_ptr[tile_offset];
+          ++shared_offset;
+        }
+        shared_offset += in_increment;
+      }
+      T* const out_ptr = inout_offset + output;
+      out_ptr[0] = sum1;
+      if (!skip_second) {
+        out_ptr[block_pixels] = sum2;
+      }
+    }
+
+    // Note: the condition to reach this is uniform across the entire block.
+    __syncthreads();
+  }
+}
+
+template <typename T, int kKnownFilterWidth, int kKnownFilterHeight,
+          bool kKnownEvenRows>
+void LaunchDepthwiseConv2dBackpropInputGPUSmall(const GpuDevice& d,
+                                                const DepthwiseArgs args,
+                                                const T* out_backprop,
+                                                const T* filter, T* in_backprop,
+                                                TensorFormat data_format) {
   const int block_rows = (args.in_rows + 1) / 2;
-  if (args.filter_rows * args.filter_cols > args.in_cols * block_rows) {
-    return false;
-  }
 
+  const int block_slices = 8;
   const int tile_cols = args.in_cols + args.filter_cols - 1;
   const int tile_rows = block_rows * 2 + args.filter_rows - 1;
   const int tile_pixels = tile_rows * tile_cols;
   const int filter_pixels = args.filter_rows * args.filter_cols;
-  dim3 block_dim = dim3(8, args.in_cols, block_rows);
-  const int shared_memory_size =
-      block_dim.x * (tile_pixels + filter_pixels) * sizeof(T);
 
-  const int num_in_backprop =
-      args.batch * args.in_rows * args.in_cols * args.in_depth;
-  if (args.in_rows & 1) {
+  const int shared_memory_size =
+      block_slices * (tile_pixels + filter_pixels) * sizeof(T);
+  const int num_outputs =
+      args.batch * args.out_rows * args.out_cols * args.out_depth;
+
+  if (data_format == FORMAT_NHWC) {
+    dim3 block_dim = dim3(block_slices, args.in_cols, block_rows);
     CudaLaunchConfig config = GetCudaLaunchConfig(
-        num_in_backprop, d,
+        num_outputs, d,
         DepthwiseConv2dBackpropInputGPUKernelNHWCSmall<
-            T, kKnownFilterWidth, kKnownFilterHeight, false>,
+            T, kKnownFilterWidth, kKnownFilterHeight, kKnownEvenRows>,
         shared_memory_size, block_dim.x * block_dim.y * block_dim.z);
-    DepthwiseConv2dBackpropInputGPUKernelNHWCSmall<T, kKnownFilterWidth,
-                                                   kKnownFilterHeight, false>
+    DepthwiseConv2dBackpropInputGPUKernelNHWCSmall<
+        T, kKnownFilterWidth, kKnownFilterHeight, kKnownEvenRows>
         <<<config.block_count, block_dim, shared_memory_size, d.stream()>>>(
             args, out_backprop, filter, in_backprop);
+  } else if (data_format == FORMAT_NCHW) {
+    dim3 block_dim = dim3(args.in_cols, block_rows, block_slices);
+    CudaLaunchConfig config = GetCudaLaunchConfig(
+        num_outputs, d,
+        DepthwiseConv2dBackpropInputGPUKernelNCHWSmall<
+            T, kKnownFilterWidth, kKnownFilterHeight, kKnownEvenRows>,
+        shared_memory_size, block_dim.x * block_dim.y * block_dim.z);
+    DepthwiseConv2dBackpropInputGPUKernelNCHWSmall<
+        T, kKnownFilterWidth, kKnownFilterHeight, kKnownEvenRows>
+        <<<config.block_count, block_dim, shared_memory_size, d.stream()>>>(
+            args, out_backprop, filter, in_backprop);
   } else {
-    CudaLaunchConfig config = GetCudaLaunchConfig(
-        num_in_backprop, d,
-        DepthwiseConv2dBackpropInputGPUKernelNHWCSmall<
-            T, kKnownFilterWidth, kKnownFilterHeight, true>,
-        shared_memory_size, block_dim.x * block_dim.y * block_dim.z);
-    DepthwiseConv2dBackpropInputGPUKernelNHWCSmall<T, kKnownFilterWidth,
-                                                   kKnownFilterHeight, true>
-        <<<config.block_count, block_dim, shared_memory_size, d.stream()>>>(
-            args, out_backprop, filter, in_backprop);
+    assert(false && "Incorrect data format");
   }
+}
 
-  return true;
+template <typename T, int kKnownFilterWidth, int kKnownFilterHeight>
+void LaunchDepthwiseConv2dBackpropInputGPUSmall(const GpuDevice& d,
+                                                const DepthwiseArgs args,
+                                                const T* out_backprop,
+                                                const T* filter, T* in_backprop,
+                                                TensorFormat data_format) {
+  if (args.in_rows & 1) {
+    LaunchDepthwiseConv2dBackpropInputGPUSmall<T, kKnownFilterWidth,
+                                               kKnownFilterHeight,
+                                               /*kKnownEvenRows=*/false>(
+        d, args, out_backprop, filter, in_backprop, data_format);
+  } else {
+    LaunchDepthwiseConv2dBackpropInputGPUSmall<T, kKnownFilterWidth,
+                                               kKnownFilterHeight,
+                                               /*kKnownEvenRows=*/true>(
+        d, args, out_backprop, filter, in_backprop, data_format);
+  }
 }
 
 template <typename T, int kKnownFilterWidth, int kKnownFilterHeight,
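To make the new launch geometry concrete, the sketch below (illustrative, not part of the commit) plugs the largest supported case, a 16x16 image with a 3x3 float filter, into the formulas used by the launcher above: one block of 16x8x8 = 1024 threads covering eight depth/batch slices at a time, which lines up with the kernel's __launch_bounds__(1024, 2), and a bit over 10 KB of dynamic shared memory for the padded image tile plus the filter.

#include <cstdio>

int main() {
  // Worked example of the small-image launch geometry; the numbers are
  // assumptions for illustration (largest supported image, 3x3 filter, float).
  const int in_rows = 16, in_cols = 16;
  const int filter_rows = 3, filter_cols = 3;
  const int sizeof_t = 4;  // sizeof(float)

  const int block_rows = (in_rows + 1) / 2;                // 8: two rows per thread
  const int block_slices = 8;                              // depth/batch slices per block
  const int tile_cols = in_cols + filter_cols - 1;         // 18: image plus halo
  const int tile_rows = block_rows * 2 + filter_rows - 1;  // 18
  const int tile_pixels = tile_rows * tile_cols;           // 324
  const int filter_pixels = filter_rows * filter_cols;     // 9
  const int shared_bytes =
      block_slices * (tile_pixels + filter_pixels) * sizeof_t;  // 10656

  // NCHW block shape is (x, y, z) = (in_cols, block_rows, block_slices).
  std::printf("block = (%d, %d, %d) = %d threads, shared memory = %d bytes\n",
              in_cols, block_rows, block_slices,
              in_cols * block_rows * block_slices, shared_bytes);
  return 0;
}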
@@ -981,9 +1136,10 @@ void LaunchDepthwiseConv2dBackpropInputGPU(const GpuDevice& d,
                                            const T* out_backprop,
                                            const T* filter, T* in_backprop,
                                            TensorFormat data_format) {
-  if (TryLaunchDepthwiseConv2dBackpropInputGPUSmall<T, kKnownFilterWidth,
-                                                    kKnownFilterHeight>(
-          d, args, out_backprop, filter, in_backprop, data_format)) {
+  if (CanLaunchDepthwiseConv2dGPUSmall(args)) {
+    LaunchDepthwiseConv2dBackpropInputGPUSmall<T, kKnownFilterWidth,
+                                               kKnownFilterHeight>(
+        d, args, out_backprop, filter, in_backprop, data_format);
     return;
   }
   const int num_in_backprop =
@@ -1009,7 +1165,7 @@ void LaunchDepthwiseConv2dBackpropInputGPU(const GpuDevice& d,
         <<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
             args, out_backprop, filter, in_backprop, num_in_backprop);
   } else {
-    assert(false);
+    assert(false && "Incorrect data format");
   }
 }
 
@@ -1462,7 +1618,7 @@ void LaunchDepthwiseConv2dBackpropFilterGPU(const GpuDevice& d,
         <<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
             args, out_backprop, input, filter_backprop, num_out_backprop);
   } else {
-    assert(false);
+    assert(false && "Incorrect data format");
   }
 }
 