Merge commit '73d15cf64320b4b77e7393efa1bf1e913404cfd6'

soumith 2016-09-13 11:16:09 -07:00
commit 05fb544f23
10 changed files with 1080 additions and 604 deletions

View File

@ -0,0 +1,154 @@
#include "THCUNN.h"
#include "common.h"
#include <thrust/functional.h>
#include <thrust/device_ptr.h>
#include <thrust/iterator/zip_iterator.h>
#include <thrust/transform.h>
#include <thrust/transform_reduce.h>
const float eps = 1e-12f;
struct bce_functor
{
template <class Tuple>
__host__ __device__
float operator()(Tuple x)
{
float o = thrust::get<0>(x);
float t = thrust::get<1>(x);
return - (t * logf(o + eps) + (1.f - t) * logf(1.f - o + eps));
}
};
struct bce_functor_weights
{
template <class Tuple>
__host__ __device__
float operator()(Tuple x)
{
float o = thrust::get<0>(x);
float t = thrust::get<1>(x);
float w = thrust::get<2>(x);
return - w * (t * logf(o + eps) + (1.f - t) * logf(1.f - o + eps));
}
};
void THNN_CudaBCECriterion_updateOutput(THCState *state, THCudaTensor *input, THCudaTensor *target, THCudaTensor *output, bool sizeAverage, THCudaTensor *weights)
{
THCUNN_assertSameGPU(state, 3, input, target, weights);
long size = THCudaTensor_nElement(state, input);
input = THCudaTensor_newContiguous(state, input);
target = THCudaTensor_newContiguous(state, target);
thrust::device_ptr<float> input_data(THCudaTensor_data(state, input));
thrust::device_ptr<float> target_data(THCudaTensor_data(state, target));
float sum;
if (weights) {
weights = THCudaTensor_newContiguous(state, weights);
thrust::device_ptr<float> weights_data(THCudaTensor_data(state, weights));
sum = thrust::transform_reduce(
thrust::make_zip_iterator(thrust::make_tuple(input_data, target_data, weights_data)),
thrust::make_zip_iterator(thrust::make_tuple(input_data+size, target_data+size, weights_data+size)),
bce_functor_weights(),
(float) 0.f,
thrust::plus<float>()
);
THCudaTensor_free(state, weights);
} else {
sum = thrust::transform_reduce(
thrust::make_zip_iterator(thrust::make_tuple(input_data, target_data)),
thrust::make_zip_iterator(thrust::make_tuple(input_data+size, target_data+size)),
bce_functor(),
(float) 0.f,
thrust::plus<float>()
);
}
if (sizeAverage)
sum /= size;
THCudaTensor_free(state, input);
THCudaTensor_free(state, target);
THCudaTensor_set1d(state, output, 0, sum);
}
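/* Gradient note for the functors below (with w = 1 when no weights tensor is given):
   the per-element loss computed above is
     l(o, t) = -w * (t * log(o + eps) + (1 - t) * log(1 - o + eps)),
   so, up to the eps terms,
     dl/do = -w * (t / o - (1 - t) / (1 - o)) = -w * (t - o) / (o * (1 - o)).
   bce_updateGradInput_functor computes exactly this quotient, keeping eps in the
   denominator for numerical stability and folding the optional 1/size factor
   (sizeAverage) into `norm`. */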
struct bce_updateGradInput_functor
{
const float norm;
bce_updateGradInput_functor(float norm_)
: norm(norm_)
{}
template <class Tuple>
__host__ __device__
float operator()(Tuple x)
{
float o = thrust::get<0>(x);
float t = thrust::get<1>(x);
return - (t - o) / ((1 - o + eps) * (o + eps)) * norm;
}
};
struct bce_updateGradInput_functor_weights
{
const float norm;
bce_updateGradInput_functor_weights(float norm_)
: norm(norm_)
{}
template <class Tuple>
__host__ __device__
float operator()(Tuple x)
{
float o = thrust::get<0>(x);
float t = thrust::get<1>(x);
float w = thrust::get<2>(x);
return - (t - o) / ((1 - o + eps) * (o + eps)) * norm * w;
}
};
void THNN_CudaBCECriterion_updateGradInput(THCState *state, THCudaTensor *input, THCudaTensor *target, THCudaTensor *gradInput, bool sizeAverage, THCudaTensor *weights)
{
THCUNN_assertSameGPU(state, 4, input, target, gradInput, weights);
long size = THCudaTensor_nElement(state, input);
float norm = (sizeAverage ? 1./size : 1.);
input = THCudaTensor_newContiguous(state, input);
target = THCudaTensor_newContiguous(state, target);
THCudaTensor_resizeAs(state, gradInput, input);
thrust::device_ptr<float> input_data(THCudaTensor_data(state, input));
thrust::device_ptr<float> target_data(THCudaTensor_data(state, target));
thrust::device_ptr<float> gradInput_data(THCudaTensor_data(state, gradInput));
if (weights) {
weights = THCudaTensor_newContiguous(state, weights);
thrust::device_ptr<float> weights_data(THCudaTensor_data(state, weights));
thrust::transform(
thrust::make_zip_iterator(thrust::make_tuple(input_data, target_data, weights_data)),
thrust::make_zip_iterator(thrust::make_tuple(input_data+size, target_data+size, weights_data+size)),
gradInput_data,
bce_updateGradInput_functor_weights(norm)
);
THCudaTensor_free(state, weights);
} else {
thrust::transform(
thrust::make_zip_iterator(thrust::make_tuple(input_data, target_data)),
thrust::make_zip_iterator(thrust::make_tuple(input_data+size, target_data+size)),
gradInput_data,
bce_updateGradInput_functor(norm)
);
}
THCudaTensor_free(state, input);
THCudaTensor_free(state, target);
}

View File

@ -9,6 +9,13 @@ IF(NOT CUDA_FOUND)
FIND_PACKAGE(CUDA 6.5 REQUIRED)
ENDIF()
# Detect CUDA architecture and get best NVCC flags
IF(NOT COMMAND CUDA_SELECT_NVCC_ARCH_FLAGS)
INCLUDE(${CMAKE_CURRENT_SOURCE_DIR}/cmake/select_compute_arch.cmake)
ENDIF()
CUDA_SELECT_NVCC_ARCH_FLAGS(NVCC_FLAGS_EXTRA $ENV{TORCH_CUDA_ARCH_LIST})
LIST(APPEND CUDA_NVCC_FLAGS ${NVCC_FLAGS_EXTRA})
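# Illustrative usage: `TORCH_CUDA_ARCH_LIST="3.5 5.2+PTX" cmake ..` pins the gencode
# flags to those architectures; when the variable is unset the argument list is empty
# and CUDA_SELECT_NVCC_ARCH_FLAGS falls back to its "Auto" detection (see
# cmake/select_compute_arch.cmake later in this diff).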
if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
if(CMAKE_CXX_COMPILER_VERSION VERSION_GREATER "4.9.3")
if(CUDA_VERSION VERSION_LESS "8.0")

View File

@ -89,15 +89,15 @@ void THNN_CudaHardTanh_updateGradInput(
THCUNN_assertSameGPU(state, 3, input, gradOutput, gradInput);
if (inplace)
{
THCudaTensor_resizeAs(state, gradInput, input);
THC_pointwiseApply3(state, gradInput, input, gradOutput,
hardtanhupdateGradInput_functor(min_val, max_val));
}
else
{
THCudaTensor_set(state, gradInput, gradOutput);
THC_pointwiseApply2(state, gradInput, input,
hardtanhupdateGradInput_functor(min_val, max_val));
}
else
{
THCudaTensor_resizeAs(state, gradInput, input);
THC_pointwiseApply3(state, gradInput, input, gradOutput,
hardtanhupdateGradInput_functor(min_val, max_val));
}
}

View File

@ -24,7 +24,7 @@ void THNN_CudaSpatialConvolutionMM_updateOutput(THCState *state, THCudaTensor *i
if (weight->nDimension == 4) {
long s1 = weight->size[0];
long s2 = weight->size[1] * weight->size[2] * weight->size[3];
weight = THCudaTensor_newWithStorage2d(state, weight->storage, 0, s1, -1, s2, -1);
weight = THCudaTensor_newWithStorage2d(state, weight->storage, weight->storageOffset, s1, -1, s2, -1);
freeWeight = 1;
}
@ -155,7 +155,7 @@ void THNN_CudaSpatialConvolutionMM_updateGradInput(THCState *state, THCudaTensor
if (weight->nDimension == 4) {
long s1 = weight->size[0];
long s2 = weight->size[1] * weight->size[2] * weight->size[3];
weight = THCudaTensor_newWithStorage2d(state, weight->storage, 0, s1, -1, s2, -1);
weight = THCudaTensor_newWithStorage2d(state, weight->storage, weight->storageOffset, s1, -1, s2, -1);
freeWeight = 1;
}
@ -252,7 +252,7 @@ void THNN_CudaSpatialConvolutionMM_accGradParameters(THCState *state, THCudaTens
if (gradWeight->nDimension == 4) {
long s1 = gradWeight->size[0];
long s2 = gradWeight->size[1] * gradWeight->size[2] * gradWeight->size[3];
gradWeight = THCudaTensor_newWithStorage2d(state, gradWeight->storage, 0, s1, -1, s2, -1);
gradWeight = THCudaTensor_newWithStorage2d(state, gradWeight->storage, gradWeight->storageOffset, s1, -1, s2, -1);
freeWeight = 1;
}

View File

@ -0,0 +1,207 @@
#include "THCUNN.h"
#include "common.h"
// kernels borrowed from Caffe
template <typename Dtype>
__global__ void MaxPoolForward(const int nthreads, const Dtype* bottom_data,
const int num, const int channels, const int height,
const int width, const int pooled_height, const int pooled_width,
const int kernel_h, const int kernel_w, const int stride_h,
const int stride_w, const int pad_h, const int pad_w,
const int dilation_h, const int dilation_w, Dtype* top_data,
Dtype* top_mask) {
CUDA_KERNEL_LOOP(index, nthreads) {
int pw = index % pooled_width;
int ph = (index / pooled_width) % pooled_height;
int c = (index / pooled_width / pooled_height) % channels;
int n = index / pooled_width / pooled_height / channels;
int hstart = ph * stride_h - pad_h;
int wstart = pw * stride_w - pad_w;
int hend = min(hstart + (kernel_h - 1) * dilation_h + 1, height);
int wend = min(wstart + (kernel_w - 1) * dilation_w + 1, width);
while(hstart < 0)
hstart += dilation_h;
while(wstart < 0)
wstart += dilation_w;
Dtype maxval = -FLT_MAX;
int maxidx = -1;
bottom_data += (n * channels + c) * height * width;
for (int h = hstart; h < hend; h += dilation_h) {
for (int w = wstart; w < wend; w += dilation_w) {
if (bottom_data[h * width + w] > maxval) {
maxidx = h * width + w;
maxval = bottom_data[maxidx];
}
}
}
top_data[index] = maxval;
top_mask[index] = maxidx + TH_INDEX_BASE;
}
}
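// Worked example of the window arithmetic above (illustrative numbers): with
// height = 7, kernel_h = 3, dilation_h = 2, stride_h = 2, pad_h = 1 and ph = 0,
// hstart = -1 and hend = min(-1 + 5, 7) = 4; the while loop bumps hstart to 1,
// so only the in-bounds taps of the dilated window, rows {1, 3}, are visited.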
template <typename Dtype>
__global__ void MaxPoolBackward(const int nthreads, const Dtype* top_diff,
const Dtype* top_mask, const int num, const int channels,
const int height, const int width, const int pooled_height,
const int pooled_width, const int kernel_h, const int kernel_w,
const int stride_h, const int stride_w, const int pad_h, const int pad_w,
const int dilation_h, const int dilation_w,
Dtype* bottom_diff) {
CUDA_KERNEL_LOOP(index, nthreads) {
// find out the local index
// find out the local offset
int w = index % width;
int h = (index / width) % height;
int c = (index / width / height) % channels;
int n = index / width / height / channels;
int phstart =
(h + pad_h < ((kernel_h - 1) * dilation_h + 1)) ? 0 : (h + pad_h - ((kernel_h - 1) * dilation_h + 1)) / stride_h + 1;
int phend = min((h + pad_h) / stride_h + 1, pooled_height);
int pwstart =
(w + pad_w < ((kernel_w - 1) * dilation_w + 1)) ? 0 : (w + pad_w - ((kernel_w - 1) * dilation_w + 1)) / stride_w + 1;
int pwend = min((w + pad_w) / stride_w + 1, pooled_width);
Dtype gradient = 0;
int offset = (n * channels + c) * pooled_height * pooled_width;
top_diff += offset;
top_mask += offset;
for (int ph = phstart; ph < phend; ++ph) {
for (int pw = pwstart; pw < pwend; ++pw) {
if (top_mask[ph * pooled_width + pw] - TH_INDEX_BASE == h * width + w) {
gradient += top_diff[ph * pooled_width + pw];
}
}
}
bottom_diff[index] = gradient;
}
}
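// Note on the bounds above: phstart/phend (and pwstart/pwend) bracket every output
// location whose dilated window could contain input element (h, w); with dilation
// this range is a superset, so the top_mask comparison is what actually decides
// whether (h, w) was the argmax of a given output and should receive its gradient.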
void THNN_CudaSpatialDilatedMaxPooling_updateOutput(THCState *state, THCudaTensor *input, THCudaTensor *output, THCudaTensor *indices, int kW, int kH, int dW, int dH, int padW, int padH, int dilationW, int dilationH, bool ceil_mode)
{
THCUNN_assertSameGPU(state, 3, input, output, indices);
THArgCheck(input->nDimension == 3 || input->nDimension == 4, 2, "3D or 4D (batch) tensor expected");
long nInputCols, nInputRows, nInputPlane, batchSize;
long nOutputCols, nOutputRows;
if (input->nDimension == 3) {
nInputCols = input->size[2];
nInputRows = input->size[1];
nInputPlane = input->size[0];
batchSize = 1;
}
else
{
nInputCols = input->size[3];
nInputRows = input->size[2];
nInputPlane = input->size[1];
batchSize = input->size[0];
}
THArgCheck(nInputCols >= kW - padW && nInputRows >= kH - padH, 2, "input image smaller than kernel size");
THArgCheck(kW/2 >= padW && kH/2 >= padH, 2, "pad should be smaller than half of kernel size");
if(ceil_mode) {
nOutputCols = ceil(float(nInputCols - (dilationW * (kW - 1) + 1) + 2*padW) / float(dW)) + 1;
nOutputRows = ceil(float(nInputRows - (dilationH * (kH - 1) + 1) + 2*padH) / float(dH)) + 1;
}
else {
nOutputCols = floor(float(nInputCols - (dilationW * (kW - 1) + 1) + 2*padW) / float(dW)) + 1;
nOutputRows = floor(float(nInputRows - (dilationH * (kH - 1) + 1) + 2*padH) / float(dH)) + 1;
}
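// Equivalently, nOutputCols = floor((nInputCols + 2*padW - (dilationW*(kW - 1) + 1)) / dW) + 1;
// e.g. (illustrative) nInputCols = 7, kW = 3, dilationW = 2, padW = 1, dW = 2 gives
// floor((7 + 2 - 5) / 2) + 1 = 3 columns, while ceil_mode rounds the quotient up instead.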
if (nOutputCols < 1 || nOutputRows < 1)
THError("Given input size: (%dx%dx%d). Calculated output size: (%dx%dx%d). Output size is too small",
nInputPlane,nInputRows,nInputCols,nInputPlane,nOutputRows,nOutputCols);
if (padW || padH)
{
// ensure that the last pooling starts inside the image
if ((nOutputRows - 1)*dH >= nInputRows + padH)
--nOutputRows;
if ((nOutputCols - 1)*dW >= nInputCols + padW)
--nOutputCols;
}
input = THCudaTensor_newContiguous(state, input);
float* input_data = THCudaTensor_data(state, input);
THCudaTensor_resize4d(state, output, batchSize, nInputPlane, nOutputRows, nOutputCols);
THCudaTensor_resizeAs(state, indices, output);
float* indices_data = THCudaTensor_data(state, indices);
float* output_data = THCudaTensor_data(state, output);
int count = THCudaTensor_nElement(state, output);
MaxPoolForward <<< GET_BLOCKS(count), CUDA_NUM_THREADS, 0, THCState_getCurrentStream(state) >>>
(count, input_data,
batchSize, nInputPlane, nInputRows, nInputCols, nOutputRows, nOutputCols,
kH, kW, dH, dW, padH, padW, dilationH, dilationW, output_data, indices_data);
THCudaCheck(cudaGetLastError());
if(input->nDimension == 3)
THCudaTensor_resize3d(state, output, nInputPlane, nOutputRows, nOutputCols);
THCudaTensor_free(state, input);
}
void THNN_CudaSpatialDilatedMaxPooling_updateGradInput(THCState *state, THCudaTensor *input, THCudaTensor *gradOutput, THCudaTensor *gradInput, THCudaTensor *indices, int kW, int kH, int dW, int dH, int padW, int padH, int dilationW, int dilationH, bool ceil_mode)
{
THCUNN_assertSameGPU(state, 4, input, gradOutput, indices, gradInput);
input = THCudaTensor_newContiguous(state, input);
gradOutput = THCudaTensor_newContiguous(state, gradOutput);
long nInputCols, nInputRows, nInputPlane, batchSize;
long nOutputCols, nOutputRows;
if (input->nDimension == 3) {
nInputCols = input->size[2];
nInputRows = input->size[1];
nInputPlane = input->size[0];
batchSize = 1;
}
else
{
nInputCols = input->size[3];
nInputRows = input->size[2];
nInputPlane = input->size[1];
batchSize = input->size[0];
}
if(ceil_mode) {
nOutputCols = ceil(float(nInputCols - (dilationW * (kW - 1) + 1) + 2*padW) / float(dW)) + 1;
nOutputRows = ceil(float(nInputRows - (dilationH * (kH - 1) + 1) + 2*padH) / float(dH)) + 1;
}
else {
nOutputCols = floor(float(nInputCols - (dilationW * (kW - 1) + 1) + 2*padW) / float(dW)) + 1;
nOutputRows = floor(float(nInputRows - (dilationH * (kH - 1) + 1) + 2*padH) / float(dH)) + 1;
}
if (nOutputCols < 1 || nOutputRows < 1)
THError("Given input size: (%dx%dx%d). Calculated output size: (%dx%dx%d). Output size is too small",
nInputPlane,nInputRows,nInputCols,nInputPlane,nOutputRows,nOutputCols);
gradOutput = THCudaTensor_newContiguous(state, gradOutput);
THCudaTensor_resizeAs(state, gradInput, input);
int count = THCudaTensor_nElement(state, input);
MaxPoolBackward <<< GET_BLOCKS(count), CUDA_NUM_THREADS, 0, THCState_getCurrentStream(state) >>>
(count,
THCudaTensor_data(state, gradOutput),
THCudaTensor_data(state, indices),
batchSize, nInputPlane, nInputRows, nInputCols, nOutputRows, nOutputCols,
kH, kW, dH, dW, padH, padW, dilationH, dilationW,
THCudaTensor_data(state, gradInput));
THCudaCheck(cudaGetLastError());
THCudaTensor_free(state, gradOutput);
// clean
THCudaTensor_free(state, input);
THCudaTensor_free(state, gradOutput);
}

View File

@ -1,207 +1,18 @@
#include "THCUNN.h"
#include "common.h"
// kernels borrowed from Caffe
template <typename Dtype>
__global__ void MaxPoolForward(const int nthreads, const Dtype* bottom_data,
const int num, const int channels, const int height,
const int width, const int pooled_height, const int pooled_width,
const int kernel_h, const int kernel_w, const int stride_h,
const int stride_w, const int pad_h, const int pad_w,
const int dilation_h, const int dilation_w, Dtype* top_data,
Dtype* top_mask) {
CUDA_KERNEL_LOOP(index, nthreads) {
int pw = index % pooled_width;
int ph = (index / pooled_width) % pooled_height;
int c = (index / pooled_width / pooled_height) % channels;
int n = index / pooled_width / pooled_height / channels;
int hstart = ph * stride_h - pad_h;
int wstart = pw * stride_w - pad_w;
int hend = min(hstart + (kernel_h - 1) * dilation_h + 1, height);
int wend = min(wstart + (kernel_w - 1) * dilation_w + 1, width);
while(hstart < 0)
hstart += dilation_h;
while(wstart < 0)
wstart += dilation_w;
Dtype maxval = -FLT_MAX;
int maxidx = -1;
bottom_data += (n * channels + c) * height * width;
for (int h = hstart; h < hend; h += dilation_h) {
for (int w = wstart; w < wend; w += dilation_w) {
if (bottom_data[h * width + w] > maxval) {
maxidx = h * width + w;
maxval = bottom_data[maxidx];
}
}
}
top_data[index] = maxval;
top_mask[index] = maxidx + TH_INDEX_BASE;
}
}
template <typename Dtype>
__global__ void MaxPoolBackward(const int nthreads, const Dtype* top_diff,
const Dtype* top_mask, const int num, const int channels,
const int height, const int width, const int pooled_height,
const int pooled_width, const int kernel_h, const int kernel_w,
const int stride_h, const int stride_w, const int pad_h, const int pad_w,
const int dilation_h, const int dilation_w,
Dtype* bottom_diff) {
CUDA_KERNEL_LOOP(index, nthreads) {
// find out the local index
// find out the local offset
int w = index % width;
int h = (index / width) % height;
int c = (index / width / height) % channels;
int n = index / width / height / channels;
int phstart =
(h + pad_h < ((kernel_h - 1) * dilation_h + 1)) ? 0 : (h + pad_h - ((kernel_h - 1) * dilation_h + 1)) / stride_h + 1;
int phend = min((h + pad_h) / stride_h + 1, pooled_height);
int pwstart =
(w + pad_w < ((kernel_w - 1) * dilation_w + 1)) ? 0 : (w + pad_w - ((kernel_w - 1) * dilation_w + 1)) / stride_w + 1;
int pwend = min((w + pad_w) / stride_w + 1, pooled_width);
Dtype gradient = 0;
int offset = (n * channels + c) * pooled_height * pooled_width;
top_diff += offset;
top_mask += offset;
for (int ph = phstart; ph < phend; ++ph) {
for (int pw = pwstart; pw < pwend; ++pw) {
if (top_mask[ph * pooled_width + pw] - TH_INDEX_BASE == h * width + w) {
gradient += top_diff[ph * pooled_width + pw];
}
}
}
bottom_diff[index] = gradient;
}
}
void THNN_CudaSpatialMaxPooling_updateOutput(THCState *state, THCudaTensor *input, THCudaTensor *output, THCudaTensor *indices, int kW, int kH, int dW, int dH, int padW, int padH, int dilationW, int dilationH, bool ceil_mode)
void THNN_CudaSpatialMaxPooling_updateOutput(THCState *state, THCudaTensor *input, THCudaTensor *output, THCudaTensor *indices, int kW, int kH, int dW, int dH, int padW, int padH, bool ceil_mode)
{
THNN_CudaSpatialDilatedMaxPooling_updateOutput(
state, input, output, indices,
kW, kH, dW, dH, padW, padH, 1, 1, ceil_mode);
THCUNN_assertSameGPU(state, 3, input, output, indices);
THArgCheck(input->nDimension == 3 || input->nDimension == 4, 2, "3D or 4D (batch) tensor expected");
long nInputCols, nInputRows, nInputPlane, batchSize;
long nOutputCols, nOutputRows;
if (input->nDimension == 3) {
nInputCols = input->size[2];
nInputRows = input->size[1];
nInputPlane = input->size[0];
batchSize = 1;
}
else
{
nInputCols = input->size[3];
nInputRows = input->size[2];
nInputPlane = input->size[1];
batchSize = input->size[0];
}
THArgCheck(nInputCols >= kW - padW && nInputRows >= kH - padH, 2, "input image smaller than kernel size");
THArgCheck(kW/2 >= padW && kH/2 >= padH, 2, "pad should be smaller than half of kernel size");
if(ceil_mode) {
nOutputCols = ceil(float(nInputCols - (dilationW * (kW - 1) + 1) + 2*padW) / float(dW)) + 1;
nOutputRows = ceil(float(nInputRows - (dilationH * (kH - 1) + 1) + 2*padH) / float(dH)) + 1;
}
else {
nOutputCols = floor(float(nInputCols - (dilationW * (kW - 1) + 1) + 2*padW) / float(dW)) + 1;
nOutputRows = floor(float(nInputRows - (dilationH * (kH - 1) + 1) + 2*padH) / float(dH)) + 1;
}
if (nOutputCols < 1 || nOutputRows < 1)
THError("Given input size: (%dx%dx%d). Calculated output size: (%dx%dx%d). Output size is too small",
nInputPlane,nInputRows,nInputCols,nInputPlane,nOutputRows,nOutputCols);
if (padW || padH)
{
// ensure that the last pooling starts inside the image
if ((nOutputRows - 1)*dH >= nInputRows + padH)
--nOutputRows;
if ((nOutputCols - 1)*dW >= nInputCols + padW)
--nOutputCols;
}
input = THCudaTensor_newContiguous(state, input);
float* input_data = THCudaTensor_data(state, input);
THCudaTensor_resize4d(state, output, batchSize, nInputPlane, nOutputRows, nOutputCols);
THCudaTensor_resizeAs(state, indices, output);
float* indices_data = THCudaTensor_data(state, indices);
float* output_data = THCudaTensor_data(state, output);
int count = THCudaTensor_nElement(state, output);
MaxPoolForward <<< GET_BLOCKS(count), CUDA_NUM_THREADS, 0, THCState_getCurrentStream(state) >>>
(count, input_data,
batchSize, nInputPlane, nInputRows, nInputCols, nOutputRows, nOutputCols,
kH, kW, dH, dW, padH, padW, dilationH, dilationW, output_data, indices_data);
THCudaCheck(cudaGetLastError());
if(input->nDimension == 3)
THCudaTensor_resize3d(state, output, nInputPlane, nOutputRows, nOutputCols);
THCudaTensor_free(state, input);
}
void THNN_CudaSpatialMaxPooling_updateGradInput(THCState *state, THCudaTensor *input, THCudaTensor *gradOutput, THCudaTensor *gradInput, THCudaTensor *indices, int kW, int kH, int dW, int dH, int padW, int padH, int dilationW, int dilationH, bool ceil_mode)
void THNN_CudaSpatialMaxPooling_updateGradInput(THCState *state, THCudaTensor *input, THCudaTensor *gradOutput, THCudaTensor *gradInput, THCudaTensor *indices, int kW, int kH, int dW, int dH, int padW, int padH, bool ceil_mode)
{
THCUNN_assertSameGPU(state, 4, input, gradOutput, indices, gradInput);
THNN_CudaSpatialDilatedMaxPooling_updateGradInput(
state, input, gradOutput, gradInput, indices,
kW, kH, dW, dH, padW, padH, 1, 1, ceil_mode);
input = THCudaTensor_newContiguous(state, input);
gradOutput = THCudaTensor_newContiguous(state, gradOutput);
long nInputCols, nInputRows, nInputPlane, batchSize;
long nOutputCols, nOutputRows;
if (input->nDimension == 3) {
nInputCols = input->size[2];
nInputRows = input->size[1];
nInputPlane = input->size[0];
batchSize = 1;
}
else
{
nInputCols = input->size[3];
nInputRows = input->size[2];
nInputPlane = input->size[1];
batchSize = input->size[0];
}
if(ceil_mode) {
nOutputCols = ceil(float(nInputCols - (dilationW * (kW - 1) + 1) + 2*padW) / float(dW)) + 1;
nOutputRows = ceil(float(nInputRows - (dilationH * (kH - 1) + 1) + 2*padH) / float(dH)) + 1;
}
else {
nOutputCols = floor(float(nInputCols - (dilationW * (kW - 1) + 1) + 2*padW) / float(dW)) + 1;
nOutputRows = floor(float(nInputRows - (dilationH * (kH - 1) + 1) + 2*padH) / float(dH)) + 1;
}
if (nOutputCols < 1 || nOutputRows < 1)
THError("Given input size: (%dx%dx%d). Calculated output size: (%dx%dx%d). Output size is too small",
nInputPlane,nInputRows,nInputCols,nInputPlane,nOutputRows,nOutputCols);
gradOutput = THCudaTensor_newContiguous(state, gradOutput);
THCudaTensor_resizeAs(state, gradInput, input);
int count = THCudaTensor_nElement(state, input);
MaxPoolBackward <<< GET_BLOCKS(count), CUDA_NUM_THREADS, 0, THCState_getCurrentStream(state) >>>
(count,
THCudaTensor_data(state, gradOutput),
THCudaTensor_data(state, indices),
batchSize, nInputPlane, nInputRows, nInputCols, nOutputRows, nOutputCols,
kH, kW, dH, dW, padH, padW, dilationH, dilationW,
THCudaTensor_data(state, gradInput));
THCudaCheck(cudaGetLastError());
THCudaTensor_free(state, gradOutput);
// clean
THCudaTensor_free(state, input);
THCudaTensor_free(state, gradOutput);
}

View File

@ -27,6 +27,21 @@ TH_API void THNN_CudaAbsCriterion_updateGradInput(
THCudaTensor *gradInput,
bool sizeAverage);
TH_API void THNN_CudaBCECriterion_updateOutput(
THCState *state,
THCudaTensor *input,
THCudaTensor *target,
THCudaTensor *output,
bool sizeAverage,
THCudaTensor *weights);
TH_API void THNN_CudaBCECriterion_updateGradInput(
THCState *state,
THCudaTensor *input,
THCudaTensor *target,
THCudaTensor *gradInput,
bool sizeAverage,
THCudaTensor *weights);
TH_API void THNN_CudaClassNLLCriterion_updateOutput(
THCState *state,
THCudaTensor *input,
@ -735,9 +750,29 @@ TH_API void THNN_CudaSpatialMaxPooling_updateOutput(
int kW, int kH,
int dW, int dH,
int padW, int padH,
int dilationW, int dilationH,
bool ceil_mode);
TH_API void THNN_CudaSpatialMaxPooling_updateGradInput(
THCState *state,
THCudaTensor *input,
THCudaTensor *gradOutput,
THCudaTensor *gradInput,
THCudaTensor *indices,
int kW, int kH,
int dW, int dH,
int padW, int padH,
bool ceil_mode);
TH_API void THNN_CudaSpatialDilatedMaxPooling_updateOutput(
THCState *state,
THCudaTensor *input,
THCudaTensor *output,
THCudaTensor *indices,
int kW, int kH,
int dW, int dH,
int padW, int padH,
int dilationW, int dilationH,
bool ceil_mode);
TH_API void THNN_CudaSpatialDilatedMaxPooling_updateGradInput(
THCState *state,
THCudaTensor *input,
THCudaTensor *gradOutput,
@ -964,6 +999,26 @@ TH_API void THNN_CudaVolumetricMaxPooling_updateGradInput(
int dT, int dW, int dH,
int padT, int padW, int padH);
TH_API void THNN_CudaVolumetricDilatedMaxPooling_updateOutput(
THCState *state,
THCudaTensor *input,
THCudaTensor *output,
THCudaTensor *indices,
int kT, int kW, int kH,
int dT, int dW, int dH,
int padT, int padW, int padH,
int dilationT, int dilationW, int dilationH,
bool ceilMode);
TH_API void THNN_CudaVolumetricDilatedMaxPooling_updateGradInput(
THCState *state,
THCudaTensor *input,
THCudaTensor *gradOutput,
THCudaTensor *gradInput,
THCudaTensor *indices,
int dT, int dW, int dH,
int padT, int padW, int padH,
int dilationT, int dilationW, int dilationH);
TH_API void THNN_CudaVolumetricMaxUnpooling_updateOutput(
THCState *state,
THCudaTensor *input,

View File

@ -0,0 +1,437 @@
#include "THCUNN.h"
#include "common.h"
#include "THCDeviceTensor.cuh"
#include "THCDeviceTensorUtils.cuh"
#include "THCDeviceUtils.cuh"
#include <cfloat>
__global__ void cuda_VolumetricDilatedMaxPooling_updateOutput(
THCDeviceTensor<float, 4> input,
THCDeviceTensor<float, 4> indices,
THCDeviceTensor<float, 4> output,
int kT, int kH, int kW,
int dT, int dH, int dW,
int padT, int padH, int padW,
int dilationT, int dilationH, int dilationW,
int offsetZ)
{
int oColumn = blockIdx.x * blockDim.x + threadIdx.x;
int oRow = blockIdx.y * blockDim.y + threadIdx.y;
int oFrame = (blockIdx.z + offsetZ) % output.getSize(1); // output frame/time
int slice = (blockIdx.z + offsetZ) / output.getSize(1); // output slice/feature
if (oRow < output.getSize(2) && oColumn < output.getSize(3))
{
int iColumn = oColumn * dW - padW;
int iRow = oRow * dH - padH;
int iFrame = oFrame * dT - padT;
int maxColumn = 0;
int maxRow = 0;
int maxFrame = 0;
float max = -FLT_MAX;
for (int frame = 0; frame < kT; ++frame)
{
if (iFrame + frame * dilationT < input.getSize(1) && iFrame + frame * dilationT >= 0)
{
for (int row = 0; row < kH; ++row)
{
if (iRow + row * dilationH < input.getSize(2) && iRow + row * dilationH >= 0)
{
for (int column = 0; column < kW; ++column)
{
if (iColumn + column * dilationW < input.getSize(3) && iColumn + column * dilationW >= 0)
{
float val = input[slice][iFrame + frame * dilationT][iRow + row * dilationH][iColumn + column * dilationW];
if (max < val)
{
max = val;
maxColumn = column;
maxRow = row;
maxFrame = frame;
}
}
}
}
}
}
}
output[slice][oFrame][oRow][oColumn] = max;
float *idx = &indices[slice][oFrame][oRow][oColumn];
((unsigned char*)(idx))[0] = maxFrame;
((unsigned char*)(idx))[1] = maxRow;
((unsigned char*)(idx))[2] = maxColumn;
((unsigned char*)(idx))[3] = 0;
}
}
template <int KERNEL_WIDTH>
__global__ void cuda_VolumetricDilatedMaxPooling_updateOutput(
THCDeviceTensor<float, 4> input, THCDeviceTensor<float, 4> indices,
THCDeviceTensor<float, 4> output,
int kT, int kH,
int dT, int dH, int dW,
int padT, int padH, int padW,
int dilationT, int dilationH, int dilationW,
int offsetZ)
{
int oColumn = blockIdx.x * blockDim.x + threadIdx.x;
int oRow = blockIdx.y * blockDim.y + threadIdx.y;
int oFrame = (blockIdx.z + offsetZ) % output.getSize(1); // output frame/time
int slice = (blockIdx.z + offsetZ) / output.getSize(1); // output slice/feature
if (oRow < output.getSize(2) && oColumn < output.getSize(3))
{
int iColumn = oColumn * dW - padW;
int iRow = oRow * dH - padH;
int iFrame = oFrame * dT - padT;
int maxColumn = 0;
int maxRow = 0;
int maxFrame = 0;
float max = -FLT_MAX;
for (int frame = 0; frame < kT; ++frame)
{
if (iFrame + frame * dilationT < input.getSize(1) && iFrame + frame * dilationT >= 0)
{
for (int row = 0; row < kH; ++row)
{
if (iRow + row * dilationH < input.getSize(2) && iRow + row * dilationH >= 0)
{
for (int column = 0; column < KERNEL_WIDTH; ++column)
{
if (iColumn + column * dilationW < input.getSize(3) && iColumn + column * dilationW >= 0)
{
float val = input[slice][iFrame + frame * dilationT][iRow + row * dilationH][iColumn + column * dilationW];
if (max < val)
{
max = val;
maxColumn = column;
maxRow = row;
maxFrame = frame;
}
}
}
}
}
}
}
output[slice][oFrame][oRow][oColumn] = max;
float *idx = &indices[slice][oFrame][oRow][oColumn];
((unsigned char*)(idx))[0] = maxFrame;
((unsigned char*)(idx))[1] = maxRow;
((unsigned char*)(idx))[2] = maxColumn;
((unsigned char*)(idx))[3] = 0;
}
}
#define UPDATE_OUTPUT_KERNEL_WIDTH(KW) case KW: \
cuda_VolumetricDilatedMaxPooling_updateOutput<KW><<<grid, block, \
0, THCState_getCurrentStream(state)>>>( \
cudaInput, cudaIndices, cudaOutput, kT, kH, dT, dH, dW, padT, padH, padW,\
dilationT, dilationH, dilationW, offsetZ); \
break
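/* The KERNEL_WIDTH specialization above gives the innermost column loop a
   compile-time bound (so the compiler can unroll it); the switch in the host
   function below dispatches kW = 1..7 to these instantiations and falls back to
   the generic kernel, which reads kW at runtime, for anything larger. */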
void THNN_CudaVolumetricDilatedMaxPooling_updateOutput(
THCState *state, THCudaTensor *input, THCudaTensor *output, THCudaTensor *indices,
int kT, int kW, int kH,
int dT, int dW, int dH,
int padT, int padW, int padH,
int dilationT, int dilationW, int dilationH,
bool ceilMode)
{
int batchSize;
int inputSlices;
int inputTime;
int inputHeight;
int inputWidth;
int outputTime;
int outputHeight;
int outputWidth;
THCUNN_assertSameGPU(state, 3, input, indices, output);
if (THCudaTensor_nDimension(state, input) == 4)
{
THArgCheck(
THCudaTensor_size(state, input, 1) >= kT &&
THCudaTensor_size(state, input, 2) >= kH &&
THCudaTensor_size(state, input, 3) >= kW, 2,
"input image smaller than kernel size"
);
/* sizes */
batchSize = 1;
inputSlices = THCudaTensor_size(state, input, 0);
inputTime = THCudaTensor_size(state, input, 1);
inputHeight = THCudaTensor_size(state, input, 2);
inputWidth = THCudaTensor_size(state, input, 3);
}
else if (THCudaTensor_nDimension(state, input) == 5)
{
THArgCheck(
THCudaTensor_size(state, input, 4) >= kW &&
THCudaTensor_size(state, input, 3) >= kH &&
THCudaTensor_size(state, input, 2) >= kT, 2,
"input image smaller than kernel size"
);
/* sizes */
batchSize = THCudaTensor_size(state, input, 0);
inputSlices = THCudaTensor_size(state, input, 1);
inputTime = THCudaTensor_size(state, input, 2);
inputHeight = THCudaTensor_size(state, input, 3);
inputWidth = THCudaTensor_size(state, input, 4);
}
else
{
THArgCheck(false, 2, "4D or 5D tensor expected");
}
THArgCheck(kT/2 >= padT && kW/2 >= padW && kH/2 >= padH, 2,
"pad should be smaller than half of kernel size"
);
if (ceilMode)
{
outputTime = (int)(ceil((float)(inputTime - (dilationT * (kT - 1) + 1) + 2*padT) / dT)) + 1;
outputHeight = (int)(ceil((float)(inputHeight - (dilationH * (kH - 1) + 1) + 2*padH) / dH)) + 1;
outputWidth = (int)(ceil((float)(inputWidth - (dilationW * (kW - 1) + 1) + 2*padW) / dW)) + 1;
}
else
{
outputTime = (int)(floor((float)(inputTime - (dilationT * (kT - 1) + 1) + 2*padT) / dT)) + 1;
outputHeight = (int)(floor((float)(inputHeight - (dilationH * (kH - 1) + 1) + 2*padH) / dH)) + 1;
outputWidth = (int)(floor((float)(inputWidth - (dilationW * (kW - 1) + 1) + 2*padW) / dW)) + 1;
}
if (outputTime < 1 || outputHeight < 1 || outputWidth < 1)
THError("Given input size: (%dx%dx%dx%d). Calculated output size: (%dx%dx%dx%d). Output size is too small",
inputSlices,inputTime,inputHeight,inputWidth,inputSlices,outputTime,outputHeight,outputWidth);
if (padT || padW || padH)
{
if ((outputTime - 1)*dT >= inputTime + padT)
--outputTime;
if ((outputHeight - 1)*dH >= inputHeight + padH)
--outputHeight;
if ((outputWidth - 1)*dW >= inputWidth + padW)
--outputWidth;
}
if (input->nDimension == 4) /* 4D */
{
/* resize output */
THCudaTensor_resize4d(state, output, inputSlices,
outputTime, outputHeight, outputWidth);
/* indices pack ti,i,j locations for each output point as uchar into
each float of the tensor */
THCudaTensor_resize4d(state, indices, inputSlices,
outputTime, outputHeight, outputWidth);
}
else
{ /* 5D */
THCudaTensor_resize5d(state, output, batchSize, inputSlices,
outputTime, outputHeight, outputWidth);
// Index tensor packs index offsets as uchars into floats
THCudaTensor_resize5d(state, indices, batchSize, inputSlices,
outputTime, outputHeight, outputWidth);
}
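/* Index packing used below: each float of `indices` is reinterpreted as four
   unsigned chars holding the argmax position inside the pooling window,
     ((unsigned char*)idx)[0] = maxFrame; [1] = maxRow; [2] = maxColumn; [3] = 0;
   and updateGradInput reads it back as, e.g.,
     iFrame = ((unsigned char*)idx)[0] * dilationT + oFrame * dT - padT;
   which implicitly assumes each kernel extent fits in an unsigned char. */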
input = THCudaTensor_newContiguous(state, input);
// Collapse batch and feature dimensions
THCDeviceTensor<float, 4> cudaInput;
THCDeviceTensor<float, 4> cudaOutput;
if (THCudaTensor_nDimension(state, input) == 4)
{
cudaInput = toDeviceTensor<float, 4>(state, input);
cudaOutput = toDeviceTensor<float, 4>(state, output);
}
else
{
cudaInput = toDeviceTensor<float, 5>(state, input).downcastOuter<4>();
cudaOutput = toDeviceTensor<float, 5>(state, output).downcastOuter<4>();
}
THLongStorage *indicesSize = THLongStorage_newWithSize(4);
long indicesSizeRaw[4] = { batchSize * inputSlices,
outputTime, outputHeight, outputWidth };
THLongStorage_rawCopy(indicesSize, indicesSizeRaw);
THCudaTensor *indices1 = THCudaTensor_newWithStorage(
state, THCudaTensor_storage(state, indices),
THCudaTensor_storageOffset(state, indices),
indicesSize, NULL);
THLongStorage_free(indicesSize);
THCDeviceTensor<float, 4> cudaIndices =
toDeviceTensor<float, 4>(state, indices1);
int totalZ = outputTime * inputSlices * batchSize;
int offsetZ = 0;
dim3 block(32, 8);
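// gridDim.z is capped at 65535, so the (time x slice x batch) planes are processed
// in chunks: each launch covers at most 65535 of them and offsetZ shifts the next chunk.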
while (totalZ > 0) {
dim3 grid(THCCeilDiv(outputWidth, static_cast<int>(block.x)),
THCCeilDiv(outputHeight, static_cast<int>(block.y)),
totalZ > 65535 ? 65535 : totalZ);
switch (kW)
{
UPDATE_OUTPUT_KERNEL_WIDTH(1);
UPDATE_OUTPUT_KERNEL_WIDTH(2);
UPDATE_OUTPUT_KERNEL_WIDTH(3);
UPDATE_OUTPUT_KERNEL_WIDTH(4);
UPDATE_OUTPUT_KERNEL_WIDTH(5);
UPDATE_OUTPUT_KERNEL_WIDTH(6);
UPDATE_OUTPUT_KERNEL_WIDTH(7);
default:
cuda_VolumetricDilatedMaxPooling_updateOutput<<<grid, block,
0, THCState_getCurrentStream(state)>>>(
cudaInput, cudaIndices, cudaOutput,
kT, kH, kW, dT, dH, dW,
padT, padH, padW, dilationT, dilationH, dilationW, offsetZ);
}
THCudaCheck(cudaGetLastError());
totalZ -= 65535;
offsetZ += 65535;
}
THCudaTensor_free(state, input);
THCudaTensor_free(state, indices1);
}
#undef UPDATE_OUTPUT_KERNEL_WIDTH
__global__ void cuda_VolumetricDilatedMaxPooling_updateGradInput(
THCDeviceTensor<float, 4> gradOutput,
THCDeviceTensor<float, 4> indices,
THCDeviceTensor<float, 4> gradInput,
int dT, int dH, int dW,
int padT, int padH, int padW,
int dilationT, int dilationH, int dilationW,
int offsetZ)
{
int oColumn = blockIdx.x * blockDim.x + threadIdx.x;
int oRow = blockIdx.y * blockDim.y + threadIdx.y;
int oFrame = (blockIdx.z + offsetZ) % gradOutput.getSize(1); // output frame/time
int slice = (blockIdx.z + offsetZ) / gradOutput.getSize(1); // output slice/feature
if (oRow < gradOutput.getSize(2) && oColumn < gradOutput.getSize(3))
{
float *idx = &indices[slice][oFrame][oRow][oColumn];
int iFrame = ((unsigned char*)(idx))[0] * dilationT + oFrame * dT - padT;
int iRow = ((unsigned char*)(idx))[1] * dilationH + oRow * dH - padH;
int iColumn = ((unsigned char*)(idx))[2] * dilationW + oColumn * dW - padW;
atomicAdd(&gradInput[slice][iFrame][iRow][iColumn],
gradOutput[slice][oFrame][oRow][oColumn]);
}
}
void THNN_CudaVolumetricDilatedMaxPooling_updateGradInput(
THCState *state, THCudaTensor *input, THCudaTensor *gradOutput, THCudaTensor *gradInput,
THCudaTensor *indices,
int dT, int dW, int dH,
int padT, int padW, int padH,
int dilationT, int dilationW, int dilationH)
{
// Resize and initialize result tensor.
THCudaTensor_resizeAs(state, gradInput, input);
THCudaTensor_zero(state, gradInput);
int batchSize;
int inputSlices;
int outputTime;
int outputHeight;
int outputWidth;
THCUNN_assertSameGPU(state, 4, input, indices, gradOutput, gradInput);
if (THCudaTensor_nDimension(state, input) == 4) /* 4D */
{
batchSize = 1;
inputSlices = THCudaTensor_size(state, input, 0);
outputTime = THCudaTensor_size(state, gradOutput, 1);
outputHeight = THCudaTensor_size(state, gradOutput, 2);
outputWidth = THCudaTensor_size(state, gradOutput, 3);
}
else
{
batchSize = THCudaTensor_size(state, input, 0);
inputSlices = THCudaTensor_size(state, input, 1);
outputTime = THCudaTensor_size(state, gradOutput, 2);
outputHeight = THCudaTensor_size(state, gradOutput, 3);
outputWidth = THCudaTensor_size(state, gradOutput, 4);
}
gradOutput = THCudaTensor_newContiguous(state, gradOutput);
// Collapse batch and feature dimensions
THCDeviceTensor<float, 4> cudaGradInput;
THCDeviceTensor<float, 4> cudaGradOutput;
if (THCudaTensor_nDimension(state, input) == 4)
{
cudaGradInput = toDeviceTensor<float, 4>(state, gradInput);
cudaGradOutput = toDeviceTensor<float, 4>(state, gradOutput);
}
else
{
cudaGradInput =
toDeviceTensor<float, 5>(state, gradInput).downcastOuter<4>();
cudaGradOutput =
toDeviceTensor<float, 5>(state, gradOutput).downcastOuter<4>();
}
THLongStorage *indicesSize = THLongStorage_newWithSize(4);
long indicesSizeRaw[4] = { batchSize * inputSlices,
outputTime, outputHeight, outputWidth };
THLongStorage_rawCopy(indicesSize, indicesSizeRaw);
THCudaTensor *indices1 = THCudaTensor_newWithStorage(
state, THCudaTensor_storage(state, indices),
THCudaTensor_storageOffset(state, indices), indicesSize, NULL);
THLongStorage_free(indicesSize);
THCDeviceTensor<float, 4> cudaIndices =
toDeviceTensor<float, 4>(state, indices1);
int totalZ = outputTime * inputSlices * batchSize;
int offsetZ = 0;
dim3 block(32, 8);
while (totalZ > 0) {
dim3 grid(THCCeilDiv(outputWidth, static_cast<int>(block.x)),
THCCeilDiv(outputHeight, static_cast<int>(block.y)),
totalZ > 65535 ? 65535 : totalZ);
cuda_VolumetricDilatedMaxPooling_updateGradInput<<<grid, block,
0, THCState_getCurrentStream(state)>>>(
cudaGradOutput,
cudaIndices,
cudaGradInput,
dT, dH, dW,
padT, padH, padW,
dilationT, dilationH, dilationW, offsetZ);
THCudaCheck(cudaGetLastError());
totalZ -= 65535;
offsetZ += 65535;
}
// cleanup
THCudaTensor_free(state, gradOutput);
THCudaTensor_free(state, indices1);
}

View File

@ -6,137 +6,6 @@
#include <cfloat>
__global__ void cuda_VolumetricMaxPooling_updateOutput(
THCDeviceTensor<float, 4> input,
THCDeviceTensor<float, 4> indices,
THCDeviceTensor<float, 4> output,
int kT, int kH, int kW,
int dT, int dH, int dW,
int padT, int padH, int padW, int offsetZ)
{
int oColumn = blockIdx.x * blockDim.x + threadIdx.x;
int oRow = blockIdx.y * blockDim.y + threadIdx.y;
int oFrame = (blockIdx.z + offsetZ) % output.getSize(1); // output frame/time
int slice = (blockIdx.z + offsetZ) / output.getSize(1); // output slice/feature
if (oRow < output.getSize(2) && oColumn < output.getSize(3))
{
int iColumn = oColumn * dW - padW;
int iRow = oRow * dH - padH;
int iFrame = oFrame * dT - padT;
int maxColumn = 0;
int maxRow = 0;
int maxFrame = 0;
float max = -FLT_MAX;
for (int frame = 0; frame < kT; ++frame)
{
if (iFrame + frame < input.getSize(1) && iFrame + frame >= 0)
{
for (int row = 0; row < kH; ++row)
{
if (iRow + row < input.getSize(2) && iRow + row >= 0)
{
for (int column = 0; column < kW; ++column)
{
if (iColumn + column < input.getSize(3) && iColumn + column >= 0)
{
float val = input[slice][iFrame + frame][iRow + row][iColumn + column];
if (max < val)
{
max = val;
maxColumn = column;
maxRow = row;
maxFrame = frame;
}
}
}
}
}
}
}
output[slice][oFrame][oRow][oColumn] = max;
float *idx = &indices[slice][oFrame][oRow][oColumn];
((unsigned char*)(idx))[0] = maxFrame;
((unsigned char*)(idx))[1] = maxRow;
((unsigned char*)(idx))[2] = maxColumn;
((unsigned char*)(idx))[3] = 0;
}
}
template <int KERNEL_WIDTH>
__global__ void cuda_VolumetricMaxPooling_updateOutput(
THCDeviceTensor<float, 4> input, THCDeviceTensor<float, 4> indices,
THCDeviceTensor<float, 4> output,
int kT, int kH,
int dT, int dH, int dW,
int padT, int padH, int padW, int offsetZ)
{
int oColumn = blockIdx.x * blockDim.x + threadIdx.x;
int oRow = blockIdx.y * blockDim.y + threadIdx.y;
int oFrame = (blockIdx.z + offsetZ) % output.getSize(1); // output frame/time
int slice = (blockIdx.z + offsetZ) / output.getSize(1); // output slice/feature
if (oRow < output.getSize(2) && oColumn < output.getSize(3))
{
int iColumn = oColumn * dW - padW;
int iRow = oRow * dH - padH;
int iFrame = oFrame * dT - padT;
int maxColumn = 0;
int maxRow = 0;
int maxFrame;
float max = -FLT_MAX;
for (int frame = 0; frame < kT; ++frame)
{
if (iFrame + frame < input.getSize(1) && iFrame + frame >= 0)
{
for (int row = 0; row < kH; ++row)
{
if (iRow + row < input.getSize(2) && iRow + row >= 0)
{
for (int column = 0; column < KERNEL_WIDTH; ++column)
{
if (iColumn + column < input.getSize(3) && iColumn + column >= 0)
{
float val = input[slice][iFrame + frame][iRow + row][iColumn + column];
if (max < val)
{
max = val;
maxColumn = column;
maxRow = row;
maxFrame = frame;
}
}
}
}
}
}
}
output[slice][oFrame][oRow][oColumn] = max;
float *idx = &indices[slice][oFrame][oRow][oColumn];
((unsigned char*)(idx))[0] = maxFrame;
((unsigned char*)(idx))[1] = maxRow;
((unsigned char*)(idx))[2] = maxColumn;
((unsigned char*)(idx))[3] = 0;
}
}
#define UPDATE_OUTPUT_KERNEL_WIDTH(KW) case KW: \
cuda_VolumetricMaxPooling_updateOutput<KW><<<grid, block, \
0, THCState_getCurrentStream(state)>>>( \
cudaInput, cudaIndices, cudaOutput, kT, kH, dT, dH, dW, padT, padH, padW, offsetZ); \
break
void THNN_CudaVolumetricMaxPooling_updateOutput(
THCState *state, THCudaTensor *input, THCudaTensor *output, THCudaTensor *indices,
int kT, int kW, int kH,
@ -144,188 +13,10 @@ void THNN_CudaVolumetricMaxPooling_updateOutput(
int padT, int padW, int padH,
bool ceilMode)
{
int batchSize;
int inputSlices;
int inputTime;
int inputHeight;
int inputWidth;
int outputTime;
int outputHeight;
int outputWidth;
THNN_CudaVolumetricDilatedMaxPooling_updateOutput(
state, input, output, indices,
kT, kW, kH, dT, dW, dH, padT, padW, padH, 1, 1, 1, ceilMode);
THCUNN_assertSameGPU(state, 3, input, indices, output);
if (THCudaTensor_nDimension(state, input) == 4)
{
THArgCheck(
THCudaTensor_size(state, input, 1) >= kT &&
THCudaTensor_size(state, input, 2) >= kH &&
THCudaTensor_size(state, input, 3) >= kW, 2,
"input image smaller than kernel size"
);
/* sizes */
batchSize = 1;
inputSlices = THCudaTensor_size(state, input, 0);
inputTime = THCudaTensor_size(state, input, 1);
inputHeight = THCudaTensor_size(state, input, 2);
inputWidth = THCudaTensor_size(state, input, 3);
}
else if (THCudaTensor_nDimension(state, input) == 5)
{
THArgCheck(
THCudaTensor_size(state, input, 4) >= kW &&
THCudaTensor_size(state, input, 3) >= kH &&
THCudaTensor_size(state, input, 2) >= kT, 2,
"input image smaller than kernel size"
);
/* sizes */
batchSize = THCudaTensor_size(state, input, 0);
inputSlices = THCudaTensor_size(state, input, 1);
inputTime = THCudaTensor_size(state, input, 2);
inputHeight = THCudaTensor_size(state, input, 3);
inputWidth = THCudaTensor_size(state, input, 4);
}
else
{
THArgCheck(false, 2, "4D or 5D tensor expected");
}
THArgCheck(kT/2 >= padT && kW/2 >= padW && kH/2 >= padH, 2,
"pad should be smaller than half of kernel size"
);
if (ceilMode)
{
outputTime = (int)(ceil((float)(inputTime - kT + 2*padT) / dT)) + 1;
outputHeight = (int)(ceil((float)(inputHeight - kH + 2*padH) / dH)) + 1;
outputWidth = (int)(ceil((float)(inputWidth - kW + 2*padW) / dW)) + 1;
}
else
{
outputTime = (int)(floor((float)(inputTime - kT + 2*padT) / dT)) + 1;
outputHeight = (int)(floor((float)(inputHeight - kH + 2*padH) / dH)) + 1;
outputWidth = (int)(floor((float)(inputWidth - kW + 2*padW) / dW)) + 1;
}
if (padT || padW || padH)
{
if ((outputTime - 1)*dT >= inputTime + padT)
--outputTime;
if ((outputHeight - 1)*dH >= inputHeight + padH)
--outputHeight;
if ((outputWidth - 1)*dW >= inputWidth + padW)
--outputWidth;
}
if (input->nDimension == 4) /* 4D */
{
/* resize output */
THCudaTensor_resize4d(state, output, inputSlices,
outputTime, outputHeight, outputWidth);
/* indices pack ti,i,j locations for each output point as uchar into
each float of the tensor */
THCudaTensor_resize4d(state, indices, inputSlices,
outputTime, outputHeight, outputWidth);
}
else
{ /* 5D */
THCudaTensor_resize5d(state, output, batchSize, inputSlices,
outputTime, outputHeight, outputWidth);
// Index tensor packs index offsets as uchars into floats
THCudaTensor_resize5d(state, indices, batchSize, inputSlices,
outputTime, outputHeight, outputWidth);
}
input = THCudaTensor_newContiguous(state, input);
// Collapse batch and feature dimensions
THCDeviceTensor<float, 4> cudaInput;
THCDeviceTensor<float, 4> cudaOutput;
if (THCudaTensor_nDimension(state, input) == 4)
{
cudaInput = toDeviceTensor<float, 4>(state, input);
cudaOutput = toDeviceTensor<float, 4>(state, output);
}
else
{
cudaInput = toDeviceTensor<float, 5>(state, input).downcastOuter<4>();
cudaOutput = toDeviceTensor<float, 5>(state, output).downcastOuter<4>();
}
THLongStorage *indicesSize = THLongStorage_newWithSize(4);
long indicesSizeRaw[4] = { batchSize * inputSlices,
outputTime, outputHeight, outputWidth };
THLongStorage_rawCopy(indicesSize, indicesSizeRaw);
THCudaTensor *indices1 = THCudaTensor_newWithStorage(
state, THCudaTensor_storage(state, indices),
THCudaTensor_storageOffset(state, indices),
indicesSize, NULL);
THLongStorage_free(indicesSize);
THCDeviceTensor<float, 4> cudaIndices =
toDeviceTensor<float, 4>(state, indices1);
int totalZ = outputTime * inputSlices * batchSize;
int offsetZ = 0;
dim3 block(32, 8);
while (totalZ > 0) {
dim3 grid(THCCeilDiv(outputWidth, static_cast<int>(block.x)),
THCCeilDiv(outputHeight, static_cast<int>(block.y)),
totalZ > 65535 ? 65535 : totalZ);
switch (kW)
{
UPDATE_OUTPUT_KERNEL_WIDTH(1);
UPDATE_OUTPUT_KERNEL_WIDTH(2);
UPDATE_OUTPUT_KERNEL_WIDTH(3);
UPDATE_OUTPUT_KERNEL_WIDTH(4);
UPDATE_OUTPUT_KERNEL_WIDTH(5);
UPDATE_OUTPUT_KERNEL_WIDTH(6);
UPDATE_OUTPUT_KERNEL_WIDTH(7);
default:
cuda_VolumetricMaxPooling_updateOutput<<<grid, block,
0, THCState_getCurrentStream(state)>>>(
cudaInput, cudaIndices, cudaOutput,
kT, kH, kW, dT, dH, dW,
padT, padH, padW, offsetZ);
}
THCudaCheck(cudaGetLastError());
totalZ -= 65535;
offsetZ += 65535;
}
THCudaTensor_free(state, input);
THCudaTensor_free(state, indices1);
}
#undef UPDATE_OUTPUT_KERNEL_WIDTH
__global__ void cuda_VolumetricMaxPooling_updateGradInput(
THCDeviceTensor<float, 4> gradOutput,
THCDeviceTensor<float, 4> indices,
THCDeviceTensor<float, 4> gradInput,
int dT, int dH, int dW,
int padT, int padH, int padW, int offsetZ)
{
int oColumn = blockIdx.x * blockDim.x + threadIdx.x;
int oRow = blockIdx.y * blockDim.y + threadIdx.y;
int oFrame = (blockIdx.z + offsetZ) % gradOutput.getSize(1); // output frame/time
int slice = (blockIdx.z + offsetZ) / gradOutput.getSize(1); // output slice/feature
if (oRow < gradOutput.getSize(2) && oColumn < gradOutput.getSize(3))
{
float *idx = &indices[slice][oFrame][oRow][oColumn];
int iFrame = ((unsigned char*)(idx))[0] + oFrame * dT - padT;
int iRow = ((unsigned char*)(idx))[1] + oRow * dH - padH;
int iColumn = ((unsigned char*)(idx))[2] + oColumn * dW - padW;
atomicAdd(&gradInput[slice][iFrame][iRow][iColumn],
gradOutput[slice][oFrame][oRow][oColumn]);
}
}
void THNN_CudaVolumetricMaxPooling_updateGradInput(
@ -334,90 +25,8 @@ void THNN_CudaVolumetricMaxPooling_updateGradInput(
int dT, int dW, int dH,
int padT, int padW, int padH)
{
// Resize and initialize result tensor.
THCudaTensor_resizeAs(state, gradInput, input);
THCudaTensor_zero(state, gradInput);
THNN_CudaVolumetricDilatedMaxPooling_updateGradInput(
state, input, gradOutput, gradInput, indices,
dT, dW, dH, padT, padW, padH, 1, 1, 1);
int batchSize;
int inputSlices;
int outputTime;
int outputHeight;
int outputWidth;
THCUNN_assertSameGPU(state, 4, input, indices, gradOutput, gradInput);
if (THCudaTensor_nDimension(state, input) == 4) /* 4D */
{
batchSize = 1;
inputSlices = THCudaTensor_size(state, input, 0);
outputTime = THCudaTensor_size(state, gradOutput, 1);
outputHeight = THCudaTensor_size(state, gradOutput, 2);
outputWidth = THCudaTensor_size(state, gradOutput, 3);
}
else
{
batchSize = THCudaTensor_size(state, input, 0);
inputSlices = THCudaTensor_size(state, input, 1);
outputTime = THCudaTensor_size(state, gradOutput, 2);
outputHeight = THCudaTensor_size(state, gradOutput, 3);
outputWidth = THCudaTensor_size(state, gradOutput, 4);
}
gradOutput = THCudaTensor_newContiguous(state, gradOutput);
// Collapse batch and feature dimensions
THCDeviceTensor<float, 4> cudaGradInput;
THCDeviceTensor<float, 4> cudaGradOutput;
if (THCudaTensor_nDimension(state, input) == 4)
{
cudaGradInput = toDeviceTensor<float, 4>(state, gradInput);
cudaGradOutput = toDeviceTensor<float, 4>(state, gradOutput);
}
else
{
cudaGradInput =
toDeviceTensor<float, 5>(state, gradInput).downcastOuter<4>();
cudaGradOutput =
toDeviceTensor<float, 5>(state, gradOutput).downcastOuter<4>();
}
THLongStorage *indicesSize = THLongStorage_newWithSize(4);
long indicesSizeRaw[4] = { batchSize * inputSlices,
outputTime, outputHeight, outputWidth };
THLongStorage_rawCopy(indicesSize, indicesSizeRaw);
THCudaTensor *indices1 = THCudaTensor_newWithStorage(
state, THCudaTensor_storage(state, indices),
THCudaTensor_storageOffset(state, indices), indicesSize, NULL);
THLongStorage_free(indicesSize);
THCDeviceTensor<float, 4> cudaIndices =
toDeviceTensor<float, 4>(state, indices1);
int totalZ = outputTime * inputSlices * batchSize;
int offsetZ = 0;
dim3 block(32, 8);
while (totalZ > 0) {
dim3 grid(THCCeilDiv(outputWidth, static_cast<int>(block.x)),
THCCeilDiv(outputHeight, static_cast<int>(block.y)),
totalZ > 65535 ? 65535 : totalZ);
cuda_VolumetricMaxPooling_updateGradInput<<<grid, block,
0, THCState_getCurrentStream(state)>>>(
cudaGradOutput,
cudaIndices,
cudaGradInput,
dT, dH, dW,
padT, padH, padW, offsetZ);
THCudaCheck(cudaGetLastError());
totalZ -= 65535;
offsetZ += 65535;
}
// cleanup
THCudaTensor_free(state, gradOutput);
THCudaTensor_free(state, indices1);
}

View File

@ -0,0 +1,196 @@
# Synopsis:
# CUDA_SELECT_NVCC_ARCH_FLAGS(out_variable [target_CUDA_architectures])
# -- Selects GPU arch flags for nvcc based on target_CUDA_architectures
# target_CUDA_architectures : Auto | Common | All | LIST(ARCH_AND_PTX ...)
# - "Auto" detects local machine GPU compute arch at runtime.
# - "Common" and "All" cover the common subset and the full set of supported architectures
# ARCH_AND_PTX : NAME | NUM.NUM | NUM.NUM(NUM.NUM) | NUM.NUM+PTX
# NAME: Fermi Kepler Maxwell Kepler+Tegra Kepler+Tesla Maxwell+Tegra Pascal
# NUM: Any number; only the pairs listed below are currently accepted by NVCC:
# 2.0 2.1 3.0 3.2 3.5 3.7 5.0 5.2 5.3 6.0 6.2
# Returns LIST of flags to be added to CUDA_NVCC_FLAGS in ${out_variable}
# Additionally, sets ${out_variable}_readable to the resulting numeric list
# Example:
# CUDA_SELECT_NVCC_ARCH_FLAGS(ARCH_FLAGS 3.0 3.5+PTX 5.2(5.0) Maxwell)
# LIST(APPEND CUDA_NVCC_FLAGS ${ARCH_FLAGS})
#
# More info on CUDA architectures: https://en.wikipedia.org/wiki/CUDA
#
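# Note on the NUM.NUM(NUM.NUM) form: the value in parentheses selects the virtual
# (compute_XX) architecture and the outer value the real (sm_XX) code, e.g.
# "5.2(5.0)" becomes -gencode arch=compute_50,code=sm_52 in the loop below.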
# This list will be used for CUDA_ARCH_NAME = All option
set(CUDA_KNOWN_GPU_ARCHITECTURES "Fermi" "Kepler" "Maxwell")
# This list will be used for CUDA_ARCH_NAME = Common option (enabled by default)
set(CUDA_COMMON_GPU_ARCHITECTURES "3.0" "3.5" "5.0")
if (CUDA_VERSION VERSION_GREATER "6.5")
list(APPEND CUDA_KNOWN_GPU_ARCHITECTURES "Kepler+Tegra" "Kepler+Tesla" "Maxwell+Tegra")
list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "5.2")
endif ()
if (CUDA_VERSION VERSION_GREATER "7.5")
list(APPEND CUDA_KNOWN_GPU_ARCHITECTURES "Pascal")
list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "6.0" "6.1" "6.1+PTX")
else()
list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "5.2+PTX")
endif ()
################################################################################################
# A function for automatic detection of GPUs installed (if autodetection is enabled)
# Usage:
# CUDA_DETECT_INSTALLED_GPUS(OUT_VARIABLE)
#
function(CUDA_DETECT_INSTALLED_GPUS OUT_VARIABLE)
if(NOT CUDA_GPU_DETECT_OUTPUT)
set(cufile ${PROJECT_BINARY_DIR}/detect_cuda_archs.cu)
file(WRITE ${cufile} ""
"#include <cstdio>\n"
"int main()\n"
"{\n"
" int count = 0;\n"
" if (cudaSuccess != cudaGetDeviceCount(&count)) return -1;\n"
" if (count == 0) return -1;\n"
" for (int device = 0; device < count; ++device)\n"
" {\n"
" cudaDeviceProp prop;\n"
" if (cudaSuccess == cudaGetDeviceProperties(&prop, device))\n"
" std::printf(\"%d.%d \", prop.major, prop.minor);\n"
" }\n"
" return 0;\n"
"}\n")
execute_process(COMMAND "${CUDA_NVCC_EXECUTABLE}" "--run" "${cufile}"
"-ccbin" ${CMAKE_CXX_COMPILER}
WORKING_DIRECTORY "${PROJECT_BINARY_DIR}/CMakeFiles/"
RESULT_VARIABLE nvcc_res OUTPUT_VARIABLE nvcc_out
ERROR_QUIET OUTPUT_STRIP_TRAILING_WHITESPACE)
if(nvcc_res EQUAL 0)
string(REPLACE "2.1" "2.1(2.0)" nvcc_out "${nvcc_out}")
set(CUDA_GPU_DETECT_OUTPUT ${nvcc_out} CACHE INTERNAL "Returned GPU architectures from detect_gpus tool" FORCE)
endif()
endif()
if(NOT CUDA_GPU_DETECT_OUTPUT)
message(STATUS "Automatic GPU detection failed. Building for common architectures.")
set(${OUT_VARIABLE} ${CUDA_COMMON_GPU_ARCHITECTURES} PARENT_SCOPE)
else()
set(${OUT_VARIABLE} ${CUDA_GPU_DETECT_OUTPUT} PARENT_SCOPE)
endif()
endfunction()
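# The detection program prints one "major.minor " pair per visible device (e.g. a
# machine with an sm_52 and an sm_35 card yields "5.2 3.5"); the 2.1 -> 2.1(2.0)
# rewrite is needed because NVCC has no compute_21 virtual architecture, so sm_21
# code must be generated from compute_20 PTX.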
################################################################################################
# Function for selecting GPU arch flags for nvcc based on CUDA architectures from parameter list
# Usage:
# SELECT_NVCC_ARCH_FLAGS(out_variable [list of CUDA compute archs])
function(CUDA_SELECT_NVCC_ARCH_FLAGS out_variable)
set(CUDA_ARCH_LIST "${ARGN}")
if("X${CUDA_ARCH_LIST}" STREQUAL "X" )
set(CUDA_ARCH_LIST "Auto")
endif()
set(cuda_arch_bin)
set(cuda_arch_ptx)
if("${CUDA_ARCH_LIST}" STREQUAL "All")
set(CUDA_ARCH_LIST ${CUDA_KNOWN_GPU_ARCHITECTURES})
elseif("${CUDA_ARCH_LIST}" STREQUAL "Common")
set(CUDA_ARCH_LIST ${CUDA_COMMON_GPU_ARCHITECTURES})
elseif("${CUDA_ARCH_LIST}" STREQUAL "Auto")
CUDA_DETECT_INSTALLED_GPUS(CUDA_ARCH_LIST)
message(STATUS "Autodetected CUDA architecture(s): ${CUDA_ARCH_LIST}")
endif()
# Now process the list and look for names
string(REGEX REPLACE "[ \t]+" ";" CUDA_ARCH_LIST "${CUDA_ARCH_LIST}")
list(REMOVE_DUPLICATES CUDA_ARCH_LIST)
foreach(arch_name ${CUDA_ARCH_LIST})
set(arch_bin)
set(add_ptx FALSE)
# Check to see if we are compiling PTX
if(arch_name MATCHES "(.*)\\+PTX$")
set(add_ptx TRUE)
set(arch_name ${CMAKE_MATCH_1})
endif()
if(arch_name MATCHES "(^[0-9]\\.[0-9](\\([0-9]\\.[0-9]\\))?)$")
set(arch_bin ${CMAKE_MATCH_1})
set(arch_ptx ${arch_bin})
else()
# Look for it in our list of known architectures
if(${arch_name} STREQUAL "Fermi")
set(arch_bin "2.0 2.1(2.0)")
elseif(${arch_name} STREQUAL "Kepler+Tegra")
set(arch_bin 3.2)
elseif(${arch_name} STREQUAL "Kepler+Tesla")
set(arch_bin 3.7)
elseif(${arch_name} STREQUAL "Kepler")
set(arch_bin 3.0 3.5)
set(arch_ptx 3.5)
elseif(${arch_name} STREQUAL "Maxwell+Tegra")
set(arch_bin 5.3)
elseif(${arch_name} STREQUAL "Maxwell")
set(arch_bin 5.0 5.2)
set(arch_ptx 5.2)
elseif(${arch_name} STREQUAL "Pascal")
set(arch_bin 6.0 6.1)
set(arch_ptx 6.1)
else()
message(SEND_ERROR "Unknown CUDA Architecture Name ${arch_name} in CUDA_SELECT_NVCC_ARCH_FLAGS")
endif()
endif()
if(NOT arch_bin)
message(SEND_ERROR "arch_bin wasn't set for some reason")
endif()
list(APPEND cuda_arch_bin ${arch_bin})
if(add_ptx)
if (NOT arch_ptx)
set(arch_ptx ${arch_bin})
endif()
list(APPEND cuda_arch_ptx ${arch_ptx})
endif()
endforeach()
# remove dots and convert to lists
string(REGEX REPLACE "\\." "" cuda_arch_bin "${cuda_arch_bin}")
string(REGEX REPLACE "\\." "" cuda_arch_ptx "${cuda_arch_ptx}")
string(REGEX MATCHALL "[0-9()]+" cuda_arch_bin "${cuda_arch_bin}")
string(REGEX MATCHALL "[0-9]+" cuda_arch_ptx "${cuda_arch_ptx}")
if(cuda_arch_bin)
list(REMOVE_DUPLICATES cuda_arch_bin)
endif()
if(cuda_arch_ptx)
list(REMOVE_DUPLICATES cuda_arch_ptx)
endif()
set(nvcc_flags "")
set(nvcc_archs_readable "")
# Tell NVCC to add binaries for the specified GPUs
foreach(arch ${cuda_arch_bin})
if(arch MATCHES "([0-9]+)\\(([0-9]+)\\)")
# User explicitly specified ARCH for the concrete CODE
list(APPEND nvcc_flags -gencode arch=compute_${CMAKE_MATCH_2},code=sm_${CMAKE_MATCH_1})
list(APPEND nvcc_archs_readable sm_${CMAKE_MATCH_1})
else()
# User didn't explicitly specify ARCH for the concrete CODE, we assume ARCH=CODE
list(APPEND nvcc_flags -gencode arch=compute_${arch},code=sm_${arch})
list(APPEND nvcc_archs_readable sm_${arch})
endif()
endforeach()
# Tell NVCC to add PTX intermediate code for the specified architectures
foreach(arch ${cuda_arch_ptx})
list(APPEND nvcc_flags -gencode arch=compute_${arch},code=compute_${arch})
list(APPEND nvcc_archs_readable compute_${arch})
endforeach()
string(REPLACE ";" " " nvcc_archs_readable "${nvcc_archs_readable}")
set(${out_variable} ${nvcc_flags} PARENT_SCOPE)
set(${out_variable}_readable ${nvcc_archs_readable} PARENT_SCOPE)
endfunction()