mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Merge commit '73d15cf64320b4b77e7393efa1bf1e913404cfd6'
This commit is contained in:
commit
05fb544f23
154
torch/lib/THCUNN/BCECriterion.cu
Normal file
154
torch/lib/THCUNN/BCECriterion.cu
Normal file
|
|
@ -0,0 +1,154 @@
|
|||
#include "THCUNN.h"
|
||||
#include "common.h"
|
||||
|
||||
#include <thrust/functional.h>
|
||||
#include <thrust/device_ptr.h>
|
||||
#include <thrust/iterator/zip_iterator.h>
|
||||
#include <thrust/transform.h>
|
||||
#include <thrust/transform_reduce.h>
|
||||
|
||||
const float eps = 1e-12f;
|
||||
|
||||
struct bce_functor
|
||||
{
|
||||
template <class Tuple>
|
||||
__host__ __device__
|
||||
float operator()(Tuple x)
|
||||
{
|
||||
float o = thrust::get<0>(x);
|
||||
float t = thrust::get<1>(x);
|
||||
return - (t * logf(o + eps) + (1.f - t) * logf(1.f - o + eps));
|
||||
}
|
||||
};
|
||||
|
||||
struct bce_functor_weights
|
||||
{
|
||||
template <class Tuple>
|
||||
__host__ __device__
|
||||
float operator()(Tuple x)
|
||||
{
|
||||
float o = thrust::get<0>(x);
|
||||
float t = thrust::get<1>(x);
|
||||
float w = thrust::get<2>(x);
|
||||
return - w * (t * logf(o + eps) + (1.f - t) * logf(1.f - o + eps));
|
||||
}
|
||||
};
|
||||
|
||||
void THNN_CudaBCECriterion_updateOutput(THCState *state, THCudaTensor *input, THCudaTensor *target, THCudaTensor *output, bool sizeAverage, THCudaTensor *weights)
|
||||
{
|
||||
THCUNN_assertSameGPU(state, 3, input, target, weights);
|
||||
|
||||
long size = THCudaTensor_nElement(state, input);
|
||||
|
||||
input = THCudaTensor_newContiguous(state, input);
|
||||
target = THCudaTensor_newContiguous(state, target);
|
||||
|
||||
thrust::device_ptr<float> input_data(THCudaTensor_data(state, input));
|
||||
thrust::device_ptr<float> target_data(THCudaTensor_data(state, target));
|
||||
|
||||
float sum;
|
||||
if (weights) {
|
||||
weights = THCudaTensor_newContiguous(state, weights);
|
||||
thrust::device_ptr<float> weights_data(THCudaTensor_data(state, weights));
|
||||
sum = thrust::transform_reduce(
|
||||
thrust::make_zip_iterator(thrust::make_tuple(input_data, target_data, weights_data)),
|
||||
thrust::make_zip_iterator(thrust::make_tuple(input_data+size, target_data+size, weights_data+size)),
|
||||
bce_functor_weights(),
|
||||
(float) 0.f,
|
||||
thrust::plus<float>()
|
||||
);
|
||||
THCudaTensor_free(state, weights);
|
||||
} else {
|
||||
sum = thrust::transform_reduce(
|
||||
thrust::make_zip_iterator(thrust::make_tuple(input_data, target_data)),
|
||||
thrust::make_zip_iterator(thrust::make_tuple(input_data+size, target_data+size)),
|
||||
bce_functor(),
|
||||
(float) 0.f,
|
||||
thrust::plus<float>()
|
||||
);
|
||||
}
|
||||
|
||||
if (sizeAverage)
|
||||
sum /= size;
|
||||
|
||||
THCudaTensor_free(state, input);
|
||||
THCudaTensor_free(state, target);
|
||||
|
||||
THCudaTensor_set1d(state, output, 0, sum);
|
||||
}
|
||||
|
||||
struct bce_updateGradInput_functor
|
||||
{
|
||||
const float norm;
|
||||
|
||||
bce_updateGradInput_functor(float norm_)
|
||||
: norm(norm_)
|
||||
{}
|
||||
|
||||
template <class Tuple>
|
||||
__host__ __device__
|
||||
float operator()(Tuple x)
|
||||
{
|
||||
float o = thrust::get<0>(x);
|
||||
float t = thrust::get<1>(x);
|
||||
return - (t - o) / ((1 - o + eps) * (o + eps)) * norm;
|
||||
}
|
||||
};
|
||||
|
||||
struct bce_updateGradInput_functor_weights
|
||||
{
|
||||
const float norm;
|
||||
|
||||
bce_updateGradInput_functor_weights(float norm_)
|
||||
: norm(norm_)
|
||||
{}
|
||||
|
||||
template <class Tuple>
|
||||
__host__ __device__
|
||||
float operator()(Tuple x)
|
||||
{
|
||||
float o = thrust::get<0>(x);
|
||||
float t = thrust::get<1>(x);
|
||||
float w = thrust::get<2>(x);
|
||||
return - (t - o) / ((1 - o + eps) * (o + eps)) * norm * w;
|
||||
}
|
||||
};
|
||||
|
||||
void THNN_CudaBCECriterion_updateGradInput(THCState *state, THCudaTensor *input, THCudaTensor *target, THCudaTensor *gradInput, bool sizeAverage, THCudaTensor *weights)
|
||||
{
|
||||
THCUNN_assertSameGPU(state, 4, input, target, gradInput, weights);
|
||||
|
||||
long size = THCudaTensor_nElement(state, input);
|
||||
float norm = (sizeAverage ? 1./size : 1.);
|
||||
|
||||
input = THCudaTensor_newContiguous(state, input);
|
||||
target = THCudaTensor_newContiguous(state, target);
|
||||
|
||||
THCudaTensor_resizeAs(state, gradInput, input);
|
||||
|
||||
thrust::device_ptr<float> input_data(THCudaTensor_data(state, input));
|
||||
thrust::device_ptr<float> target_data(THCudaTensor_data(state, target));
|
||||
thrust::device_ptr<float> gradInput_data(THCudaTensor_data(state, gradInput));
|
||||
|
||||
if (weights) {
|
||||
weights = THCudaTensor_newContiguous(state, weights);
|
||||
thrust::device_ptr<float> weights_data(THCudaTensor_data(state, weights));
|
||||
thrust::transform(
|
||||
thrust::make_zip_iterator(thrust::make_tuple(input_data, target_data, weights_data)),
|
||||
thrust::make_zip_iterator(thrust::make_tuple(input_data+size, target_data+size, weights_data+size)),
|
||||
gradInput_data,
|
||||
bce_updateGradInput_functor_weights(norm)
|
||||
);
|
||||
THCudaTensor_free(state, weights);
|
||||
} else {
|
||||
thrust::transform(
|
||||
thrust::make_zip_iterator(thrust::make_tuple(input_data, target_data)),
|
||||
thrust::make_zip_iterator(thrust::make_tuple(input_data+size, target_data+size)),
|
||||
gradInput_data,
|
||||
bce_updateGradInput_functor(norm)
|
||||
);
|
||||
}
|
||||
|
||||
THCudaTensor_free(state, input);
|
||||
THCudaTensor_free(state, target);
|
||||
}
|
||||
|
|
@ -9,6 +9,13 @@ IF(NOT CUDA_FOUND)
|
|||
FIND_PACKAGE(CUDA 6.5 REQUIRED)
|
||||
ENDIF()
|
||||
|
||||
# Detect CUDA architecture and get best NVCC flags
|
||||
IF(NOT COMMAND CUDA_SELECT_NVCC_ARCH_FLAGS)
|
||||
INCLUDE(${CMAKE_CURRENT_SOURCE_DIR}/cmake/select_compute_arch.cmake)
|
||||
ENDIF()
|
||||
CUDA_SELECT_NVCC_ARCH_FLAGS(NVCC_FLAGS_EXTRA $ENV{TORCH_CUDA_ARCH_LIST})
|
||||
LIST(APPEND CUDA_NVCC_FLAGS ${NVCC_FLAGS_EXTRA})
|
||||
|
||||
if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
|
||||
if(CMAKE_CXX_COMPILER_VERSION VERSION_GREATER "4.9.3")
|
||||
if(CUDA_VERSION VERSION_LESS "8.0")
|
||||
|
|
|
|||
|
|
@ -89,15 +89,15 @@ void THNN_CudaHardTanh_updateGradInput(
|
|||
THCUNN_assertSameGPU(state, 3, input, gradOutput, gradInput);
|
||||
|
||||
if (inplace)
|
||||
{
|
||||
THCudaTensor_resizeAs(state, gradInput, input);
|
||||
THC_pointwiseApply3(state, gradInput, input, gradOutput,
|
||||
hardtanhupdateGradInput_functor(min_val, max_val));
|
||||
}
|
||||
else
|
||||
{
|
||||
THCudaTensor_set(state, gradInput, gradOutput);
|
||||
THC_pointwiseApply2(state, gradInput, input,
|
||||
hardtanhupdateGradInput_functor(min_val, max_val));
|
||||
}
|
||||
else
|
||||
{
|
||||
THCudaTensor_resizeAs(state, gradInput, input);
|
||||
THC_pointwiseApply3(state, gradInput, input, gradOutput,
|
||||
hardtanhupdateGradInput_functor(min_val, max_val));
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -24,7 +24,7 @@ void THNN_CudaSpatialConvolutionMM_updateOutput(THCState *state, THCudaTensor *i
|
|||
if (weight->nDimension == 4) {
|
||||
long s1 = weight->size[0];
|
||||
long s2 = weight->size[1] * weight->size[2] * weight->size[3];
|
||||
weight = THCudaTensor_newWithStorage2d(state, weight->storage, 0, s1, -1, s2, -1);
|
||||
weight = THCudaTensor_newWithStorage2d(state, weight->storage, weight->storageOffset, s1, -1, s2, -1);
|
||||
freeWeight = 1;
|
||||
}
|
||||
|
||||
|
|
@ -155,7 +155,7 @@ void THNN_CudaSpatialConvolutionMM_updateGradInput(THCState *state, THCudaTensor
|
|||
if (weight->nDimension == 4) {
|
||||
long s1 = weight->size[0];
|
||||
long s2 = weight->size[1] * weight->size[2] * weight->size[3];
|
||||
weight = THCudaTensor_newWithStorage2d(state, weight->storage, 0, s1, -1, s2, -1);
|
||||
weight = THCudaTensor_newWithStorage2d(state, weight->storage, weight->storageOffset, s1, -1, s2, -1);
|
||||
freeWeight = 1;
|
||||
}
|
||||
|
||||
|
|
@ -252,7 +252,7 @@ void THNN_CudaSpatialConvolutionMM_accGradParameters(THCState *state, THCudaTens
|
|||
if (gradWeight->nDimension == 4) {
|
||||
long s1 = gradWeight->size[0];
|
||||
long s2 = gradWeight->size[1] * gradWeight->size[2] * gradWeight->size[3];
|
||||
gradWeight = THCudaTensor_newWithStorage2d(state, gradWeight->storage, 0, s1, -1, s2, -1);
|
||||
gradWeight = THCudaTensor_newWithStorage2d(state, gradWeight->storage, gradWeight->storageOffset, s1, -1, s2, -1);
|
||||
freeWeight = 1;
|
||||
}
|
||||
|
||||
|
|
|
|||
207
torch/lib/THCUNN/SpatialDilatedMaxPooling.cu
Normal file
207
torch/lib/THCUNN/SpatialDilatedMaxPooling.cu
Normal file
|
|
@ -0,0 +1,207 @@
|
|||
#include "THCUNN.h"
|
||||
#include "common.h"
|
||||
|
||||
// kernels borrowed from Caffe
|
||||
template <typename Dtype>
|
||||
__global__ void MaxPoolForward(const int nthreads, const Dtype* bottom_data,
|
||||
const int num, const int channels, const int height,
|
||||
const int width, const int pooled_height, const int pooled_width,
|
||||
const int kernel_h, const int kernel_w, const int stride_h,
|
||||
const int stride_w, const int pad_h, const int pad_w,
|
||||
const int dilation_h, const int dilation_w, Dtype* top_data,
|
||||
Dtype* top_mask) {
|
||||
CUDA_KERNEL_LOOP(index, nthreads) {
|
||||
int pw = index % pooled_width;
|
||||
int ph = (index / pooled_width) % pooled_height;
|
||||
int c = (index / pooled_width / pooled_height) % channels;
|
||||
int n = index / pooled_width / pooled_height / channels;
|
||||
int hstart = ph * stride_h - pad_h;
|
||||
int wstart = pw * stride_w - pad_w;
|
||||
int hend = min(hstart + (kernel_h - 1) * dilation_h + 1, height);
|
||||
int wend = min(wstart + (kernel_w - 1) * dilation_w + 1, width);
|
||||
while(hstart < 0)
|
||||
hstart += dilation_h;
|
||||
while(wstart < 0)
|
||||
wstart += dilation_w;
|
||||
Dtype maxval = -FLT_MAX;
|
||||
int maxidx = -1;
|
||||
bottom_data += (n * channels + c) * height * width;
|
||||
for (int h = hstart; h < hend; h += dilation_h) {
|
||||
for (int w = wstart; w < wend; w += dilation_w) {
|
||||
if (bottom_data[h * width + w] > maxval) {
|
||||
maxidx = h * width + w;
|
||||
maxval = bottom_data[maxidx];
|
||||
}
|
||||
}
|
||||
}
|
||||
top_data[index] = maxval;
|
||||
top_mask[index] = maxidx + TH_INDEX_BASE;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
template <typename Dtype>
|
||||
__global__ void MaxPoolBackward(const int nthreads, const Dtype* top_diff,
|
||||
const Dtype* top_mask, const int num, const int channels,
|
||||
const int height, const int width, const int pooled_height,
|
||||
const int pooled_width, const int kernel_h, const int kernel_w,
|
||||
const int stride_h, const int stride_w, const int pad_h, const int pad_w,
|
||||
const int dilation_h, const int dilation_w,
|
||||
Dtype* bottom_diff) {
|
||||
CUDA_KERNEL_LOOP(index, nthreads) {
|
||||
// find out the local index
|
||||
// find out the local offset
|
||||
int w = index % width;
|
||||
int h = (index / width) % height;
|
||||
int c = (index / width / height) % channels;
|
||||
int n = index / width / height / channels;
|
||||
int phstart =
|
||||
(h + pad_h < ((kernel_h - 1) * dilation_h + 1)) ? 0 : (h + pad_h - ((kernel_h - 1) * dilation_h + 1)) / stride_h + 1;
|
||||
int phend = min((h + pad_h) / stride_h + 1, pooled_height);
|
||||
int pwstart =
|
||||
(w + pad_w < ((kernel_w - 1) * dilation_w + 1)) ? 0 : (w + pad_w - ((kernel_w - 1) * dilation_w + 1)) / stride_w + 1;
|
||||
int pwend = min((w + pad_w) / stride_w + 1, pooled_width);
|
||||
|
||||
Dtype gradient = 0;
|
||||
int offset = (n * channels + c) * pooled_height * pooled_width;
|
||||
top_diff += offset;
|
||||
top_mask += offset;
|
||||
for (int ph = phstart; ph < phend; ++ph) {
|
||||
for (int pw = pwstart; pw < pwend; ++pw) {
|
||||
if (top_mask[ph * pooled_width + pw] - TH_INDEX_BASE == h * width + w) {
|
||||
gradient += top_diff[ph * pooled_width + pw];
|
||||
}
|
||||
}
|
||||
}
|
||||
bottom_diff[index] = gradient;
|
||||
}
|
||||
}
|
||||
|
||||
void THNN_CudaSpatialDilatedMaxPooling_updateOutput(THCState *state, THCudaTensor *input, THCudaTensor *output, THCudaTensor *indices, int kW, int kH, int dW, int dH, int padW, int padH, int dilationW, int dilationH, bool ceil_mode)
|
||||
{
|
||||
|
||||
THCUNN_assertSameGPU(state, 3, input, output, indices);
|
||||
THArgCheck(input->nDimension == 3 || input->nDimension == 4, 2, "3D or 4D (batch) tensor expected");
|
||||
|
||||
long nInputCols, nInputRows, nInputPlane, batchSize;
|
||||
long nOutputCols, nOutputRows;
|
||||
|
||||
if (input->nDimension == 3) {
|
||||
nInputCols = input->size[2];
|
||||
nInputRows = input->size[1];
|
||||
nInputPlane = input->size[0];
|
||||
batchSize = 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
nInputCols = input->size[3];
|
||||
nInputRows = input->size[2];
|
||||
nInputPlane = input->size[1];
|
||||
batchSize = input->size[0];
|
||||
}
|
||||
|
||||
THArgCheck(nInputCols >= kW - padW && nInputRows >= kH - padH, 2, "input image smaller than kernel size");
|
||||
THArgCheck(kW/2 >= padW && kH/2 >= padH, 2, "pad should be smaller than half of kernel size");
|
||||
|
||||
if(ceil_mode) {
|
||||
nOutputCols = ceil(float(nInputCols - (dilationW * (kW - 1) + 1) + 2*padW) / float(dW)) + 1;
|
||||
nOutputRows = ceil(float(nInputRows - (dilationH * (kH - 1) + 1) + 2*padH) / float(dH)) + 1;
|
||||
}
|
||||
else {
|
||||
nOutputCols = floor(float(nInputCols - (dilationW * (kW - 1) + 1) + 2*padW) / float(dW)) + 1;
|
||||
nOutputRows = floor(float(nInputRows - (dilationH * (kH - 1) + 1) + 2*padH) / float(dH)) + 1;
|
||||
}
|
||||
|
||||
if (nOutputCols < 1 || nOutputRows < 1)
|
||||
THError("Given input size: (%dx%dx%d). Calculated output size: (%dx%dx%d). Output size is too small",
|
||||
nInputPlane,nInputRows,nInputCols,nInputPlane,nOutputRows,nOutputCols);
|
||||
|
||||
if (padW || padH)
|
||||
{
|
||||
// ensure that the last pooling starts inside the image
|
||||
if ((nOutputRows - 1)*dH >= nInputRows + padH)
|
||||
--nOutputRows;
|
||||
if ((nOutputCols - 1)*dW >= nInputCols + padW)
|
||||
--nOutputCols;
|
||||
}
|
||||
|
||||
input = THCudaTensor_newContiguous(state, input);
|
||||
float* input_data = THCudaTensor_data(state, input);
|
||||
|
||||
THCudaTensor_resize4d(state, output, batchSize, nInputPlane, nOutputRows, nOutputCols);
|
||||
THCudaTensor_resizeAs(state, indices, output);
|
||||
|
||||
float* indices_data = THCudaTensor_data(state, indices);
|
||||
float* output_data = THCudaTensor_data(state, output);
|
||||
|
||||
int count = THCudaTensor_nElement(state, output);
|
||||
|
||||
MaxPoolForward <<< GET_BLOCKS(count), CUDA_NUM_THREADS, 0, THCState_getCurrentStream(state) >>>
|
||||
(count, input_data,
|
||||
batchSize, nInputPlane, nInputRows, nInputCols, nOutputRows, nOutputCols,
|
||||
kH, kW, dH, dW, padH, padW, dilationH, dilationW, output_data, indices_data);
|
||||
THCudaCheck(cudaGetLastError());
|
||||
|
||||
if(input->nDimension == 3)
|
||||
THCudaTensor_resize3d(state, output, nInputPlane, nOutputRows, nOutputCols);
|
||||
|
||||
THCudaTensor_free(state, input);
|
||||
}
|
||||
|
||||
void THNN_CudaSpatialDilatedMaxPooling_updateGradInput(THCState *state, THCudaTensor *input, THCudaTensor *gradOutput, THCudaTensor *gradInput, THCudaTensor *indices, int kW, int kH, int dW, int dH, int padW, int padH, int dilationW, int dilationH, bool ceil_mode)
|
||||
{
|
||||
THCUNN_assertSameGPU(state, 4, input, gradOutput, indices, gradInput);
|
||||
|
||||
input = THCudaTensor_newContiguous(state, input);
|
||||
gradOutput = THCudaTensor_newContiguous(state, gradOutput);
|
||||
|
||||
long nInputCols, nInputRows, nInputPlane, batchSize;
|
||||
long nOutputCols, nOutputRows;
|
||||
|
||||
if (input->nDimension == 3) {
|
||||
nInputCols = input->size[2];
|
||||
nInputRows = input->size[1];
|
||||
nInputPlane = input->size[0];
|
||||
batchSize = 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
nInputCols = input->size[3];
|
||||
nInputRows = input->size[2];
|
||||
nInputPlane = input->size[1];
|
||||
batchSize = input->size[0];
|
||||
}
|
||||
|
||||
if(ceil_mode) {
|
||||
nOutputCols = ceil(float(nInputCols - (dilationW * (kW - 1) + 1) + 2*padW) / float(dW)) + 1;
|
||||
nOutputRows = ceil(float(nInputRows - (dilationH * (kH - 1) + 1) + 2*padH) / float(dH)) + 1;
|
||||
}
|
||||
else {
|
||||
nOutputCols = floor(float(nInputCols - (dilationW * (kW - 1) + 1) + 2*padW) / float(dW)) + 1;
|
||||
nOutputRows = floor(float(nInputRows - (dilationH * (kH - 1) + 1) + 2*padH) / float(dH)) + 1;
|
||||
}
|
||||
|
||||
if (nOutputCols < 1 || nOutputRows < 1)
|
||||
THError("Given input size: (%dx%dx%d). Calculated output size: (%dx%dx%d). Output size is too small",
|
||||
nInputPlane,nInputRows,nInputCols,nInputPlane,nOutputRows,nOutputCols);
|
||||
|
||||
gradOutput = THCudaTensor_newContiguous(state, gradOutput);
|
||||
THCudaTensor_resizeAs(state, gradInput, input);
|
||||
|
||||
int count = THCudaTensor_nElement(state, input);
|
||||
|
||||
MaxPoolBackward <<< GET_BLOCKS(count), CUDA_NUM_THREADS, 0, THCState_getCurrentStream(state) >>>
|
||||
(count,
|
||||
THCudaTensor_data(state, gradOutput),
|
||||
THCudaTensor_data(state, indices),
|
||||
batchSize, nInputPlane, nInputRows, nInputCols, nOutputRows, nOutputCols,
|
||||
kH, kW, dH, dW, padH, padW, dilationH, dilationW,
|
||||
THCudaTensor_data(state, gradInput));
|
||||
THCudaCheck(cudaGetLastError());
|
||||
|
||||
THCudaTensor_free(state, gradOutput);
|
||||
|
||||
// clean
|
||||
THCudaTensor_free(state, input);
|
||||
THCudaTensor_free(state, gradOutput);
|
||||
}
|
||||
|
|
@ -1,207 +1,18 @@
|
|||
#include "THCUNN.h"
|
||||
#include "common.h"
|
||||
|
||||
// kernels borrowed from Caffe
|
||||
template <typename Dtype>
|
||||
__global__ void MaxPoolForward(const int nthreads, const Dtype* bottom_data,
|
||||
const int num, const int channels, const int height,
|
||||
const int width, const int pooled_height, const int pooled_width,
|
||||
const int kernel_h, const int kernel_w, const int stride_h,
|
||||
const int stride_w, const int pad_h, const int pad_w,
|
||||
const int dilation_h, const int dilation_w, Dtype* top_data,
|
||||
Dtype* top_mask) {
|
||||
CUDA_KERNEL_LOOP(index, nthreads) {
|
||||
int pw = index % pooled_width;
|
||||
int ph = (index / pooled_width) % pooled_height;
|
||||
int c = (index / pooled_width / pooled_height) % channels;
|
||||
int n = index / pooled_width / pooled_height / channels;
|
||||
int hstart = ph * stride_h - pad_h;
|
||||
int wstart = pw * stride_w - pad_w;
|
||||
int hend = min(hstart + (kernel_h - 1) * dilation_h + 1, height);
|
||||
int wend = min(wstart + (kernel_w - 1) * dilation_w + 1, width);
|
||||
while(hstart < 0)
|
||||
hstart += dilation_h;
|
||||
while(wstart < 0)
|
||||
wstart += dilation_w;
|
||||
Dtype maxval = -FLT_MAX;
|
||||
int maxidx = -1;
|
||||
bottom_data += (n * channels + c) * height * width;
|
||||
for (int h = hstart; h < hend; h += dilation_h) {
|
||||
for (int w = wstart; w < wend; w += dilation_w) {
|
||||
if (bottom_data[h * width + w] > maxval) {
|
||||
maxidx = h * width + w;
|
||||
maxval = bottom_data[maxidx];
|
||||
}
|
||||
}
|
||||
}
|
||||
top_data[index] = maxval;
|
||||
top_mask[index] = maxidx + TH_INDEX_BASE;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
template <typename Dtype>
|
||||
__global__ void MaxPoolBackward(const int nthreads, const Dtype* top_diff,
|
||||
const Dtype* top_mask, const int num, const int channels,
|
||||
const int height, const int width, const int pooled_height,
|
||||
const int pooled_width, const int kernel_h, const int kernel_w,
|
||||
const int stride_h, const int stride_w, const int pad_h, const int pad_w,
|
||||
const int dilation_h, const int dilation_w,
|
||||
Dtype* bottom_diff) {
|
||||
CUDA_KERNEL_LOOP(index, nthreads) {
|
||||
// find out the local index
|
||||
// find out the local offset
|
||||
int w = index % width;
|
||||
int h = (index / width) % height;
|
||||
int c = (index / width / height) % channels;
|
||||
int n = index / width / height / channels;
|
||||
int phstart =
|
||||
(h + pad_h < ((kernel_h - 1) * dilation_h + 1)) ? 0 : (h + pad_h - ((kernel_h - 1) * dilation_h + 1)) / stride_h + 1;
|
||||
int phend = min((h + pad_h) / stride_h + 1, pooled_height);
|
||||
int pwstart =
|
||||
(w + pad_w < ((kernel_w - 1) * dilation_w + 1)) ? 0 : (w + pad_w - ((kernel_w - 1) * dilation_w + 1)) / stride_w + 1;
|
||||
int pwend = min((w + pad_w) / stride_w + 1, pooled_width);
|
||||
|
||||
Dtype gradient = 0;
|
||||
int offset = (n * channels + c) * pooled_height * pooled_width;
|
||||
top_diff += offset;
|
||||
top_mask += offset;
|
||||
for (int ph = phstart; ph < phend; ++ph) {
|
||||
for (int pw = pwstart; pw < pwend; ++pw) {
|
||||
if (top_mask[ph * pooled_width + pw] - TH_INDEX_BASE == h * width + w) {
|
||||
gradient += top_diff[ph * pooled_width + pw];
|
||||
}
|
||||
}
|
||||
}
|
||||
bottom_diff[index] = gradient;
|
||||
}
|
||||
}
|
||||
|
||||
void THNN_CudaSpatialMaxPooling_updateOutput(THCState *state, THCudaTensor *input, THCudaTensor *output, THCudaTensor *indices, int kW, int kH, int dW, int dH, int padW, int padH, int dilationW, int dilationH, bool ceil_mode)
|
||||
void THNN_CudaSpatialMaxPooling_updateOutput(THCState *state, THCudaTensor *input, THCudaTensor *output, THCudaTensor *indices, int kW, int kH, int dW, int dH, int padW, int padH, bool ceil_mode)
|
||||
{
|
||||
THNN_CudaSpatialDilatedMaxPooling_updateOutput(
|
||||
state, input, output, indices,
|
||||
kW, kH, dW, dH, padW, padH, 1, 1, ceil_mode);
|
||||
|
||||
THCUNN_assertSameGPU(state, 3, input, output, indices);
|
||||
THArgCheck(input->nDimension == 3 || input->nDimension == 4, 2, "3D or 4D (batch) tensor expected");
|
||||
|
||||
long nInputCols, nInputRows, nInputPlane, batchSize;
|
||||
long nOutputCols, nOutputRows;
|
||||
|
||||
if (input->nDimension == 3) {
|
||||
nInputCols = input->size[2];
|
||||
nInputRows = input->size[1];
|
||||
nInputPlane = input->size[0];
|
||||
batchSize = 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
nInputCols = input->size[3];
|
||||
nInputRows = input->size[2];
|
||||
nInputPlane = input->size[1];
|
||||
batchSize = input->size[0];
|
||||
}
|
||||
|
||||
THArgCheck(nInputCols >= kW - padW && nInputRows >= kH - padH, 2, "input image smaller than kernel size");
|
||||
THArgCheck(kW/2 >= padW && kH/2 >= padH, 2, "pad should be smaller than half of kernel size");
|
||||
|
||||
if(ceil_mode) {
|
||||
nOutputCols = ceil(float(nInputCols - (dilationW * (kW - 1) + 1) + 2*padW) / float(dW)) + 1;
|
||||
nOutputRows = ceil(float(nInputRows - (dilationH * (kH - 1) + 1) + 2*padH) / float(dH)) + 1;
|
||||
}
|
||||
else {
|
||||
nOutputCols = floor(float(nInputCols - (dilationW * (kW - 1) + 1) + 2*padW) / float(dW)) + 1;
|
||||
nOutputRows = floor(float(nInputRows - (dilationH * (kH - 1) + 1) + 2*padH) / float(dH)) + 1;
|
||||
}
|
||||
|
||||
if (nOutputCols < 1 || nOutputRows < 1)
|
||||
THError("Given input size: (%dx%dx%d). Calculated output size: (%dx%dx%d). Output size is too small",
|
||||
nInputPlane,nInputRows,nInputCols,nInputPlane,nOutputRows,nOutputCols);
|
||||
|
||||
if (padW || padH)
|
||||
{
|
||||
// ensure that the last pooling starts inside the image
|
||||
if ((nOutputRows - 1)*dH >= nInputRows + padH)
|
||||
--nOutputRows;
|
||||
if ((nOutputCols - 1)*dW >= nInputCols + padW)
|
||||
--nOutputCols;
|
||||
}
|
||||
|
||||
input = THCudaTensor_newContiguous(state, input);
|
||||
float* input_data = THCudaTensor_data(state, input);
|
||||
|
||||
THCudaTensor_resize4d(state, output, batchSize, nInputPlane, nOutputRows, nOutputCols);
|
||||
THCudaTensor_resizeAs(state, indices, output);
|
||||
|
||||
float* indices_data = THCudaTensor_data(state, indices);
|
||||
float* output_data = THCudaTensor_data(state, output);
|
||||
|
||||
int count = THCudaTensor_nElement(state, output);
|
||||
|
||||
MaxPoolForward <<< GET_BLOCKS(count), CUDA_NUM_THREADS, 0, THCState_getCurrentStream(state) >>>
|
||||
(count, input_data,
|
||||
batchSize, nInputPlane, nInputRows, nInputCols, nOutputRows, nOutputCols,
|
||||
kH, kW, dH, dW, padH, padW, dilationH, dilationW, output_data, indices_data);
|
||||
THCudaCheck(cudaGetLastError());
|
||||
|
||||
if(input->nDimension == 3)
|
||||
THCudaTensor_resize3d(state, output, nInputPlane, nOutputRows, nOutputCols);
|
||||
|
||||
THCudaTensor_free(state, input);
|
||||
}
|
||||
|
||||
void THNN_CudaSpatialMaxPooling_updateGradInput(THCState *state, THCudaTensor *input, THCudaTensor *gradOutput, THCudaTensor *gradInput, THCudaTensor *indices, int kW, int kH, int dW, int dH, int padW, int padH, int dilationW, int dilationH, bool ceil_mode)
|
||||
void THNN_CudaSpatialMaxPooling_updateGradInput(THCState *state, THCudaTensor *input, THCudaTensor *gradOutput, THCudaTensor *gradInput, THCudaTensor *indices, int kW, int kH, int dW, int dH, int padW, int padH, bool ceil_mode)
|
||||
{
|
||||
THCUNN_assertSameGPU(state, 4, input, gradOutput, indices, gradInput);
|
||||
THNN_CudaSpatialDilatedMaxPooling_updateGradInput(
|
||||
state, input, gradOutput, gradInput, indices,
|
||||
kW, kH, dW, dH, padW, padH, 1, 1, ceil_mode);
|
||||
|
||||
input = THCudaTensor_newContiguous(state, input);
|
||||
gradOutput = THCudaTensor_newContiguous(state, gradOutput);
|
||||
|
||||
long nInputCols, nInputRows, nInputPlane, batchSize;
|
||||
long nOutputCols, nOutputRows;
|
||||
|
||||
if (input->nDimension == 3) {
|
||||
nInputCols = input->size[2];
|
||||
nInputRows = input->size[1];
|
||||
nInputPlane = input->size[0];
|
||||
batchSize = 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
nInputCols = input->size[3];
|
||||
nInputRows = input->size[2];
|
||||
nInputPlane = input->size[1];
|
||||
batchSize = input->size[0];
|
||||
}
|
||||
|
||||
if(ceil_mode) {
|
||||
nOutputCols = ceil(float(nInputCols - (dilationW * (kW - 1) + 1) + 2*padW) / float(dW)) + 1;
|
||||
nOutputRows = ceil(float(nInputRows - (dilationH * (kH - 1) + 1) + 2*padH) / float(dH)) + 1;
|
||||
}
|
||||
else {
|
||||
nOutputCols = floor(float(nInputCols - (dilationW * (kW - 1) + 1) + 2*padW) / float(dW)) + 1;
|
||||
nOutputRows = floor(float(nInputRows - (dilationH * (kH - 1) + 1) + 2*padH) / float(dH)) + 1;
|
||||
}
|
||||
|
||||
if (nOutputCols < 1 || nOutputRows < 1)
|
||||
THError("Given input size: (%dx%dx%d). Calculated output size: (%dx%dx%d). Output size is too small",
|
||||
nInputPlane,nInputRows,nInputCols,nInputPlane,nOutputRows,nOutputCols);
|
||||
|
||||
gradOutput = THCudaTensor_newContiguous(state, gradOutput);
|
||||
THCudaTensor_resizeAs(state, gradInput, input);
|
||||
|
||||
int count = THCudaTensor_nElement(state, input);
|
||||
|
||||
MaxPoolBackward <<< GET_BLOCKS(count), CUDA_NUM_THREADS, 0, THCState_getCurrentStream(state) >>>
|
||||
(count,
|
||||
THCudaTensor_data(state, gradOutput),
|
||||
THCudaTensor_data(state, indices),
|
||||
batchSize, nInputPlane, nInputRows, nInputCols, nOutputRows, nOutputCols,
|
||||
kH, kW, dH, dW, padH, padW, dilationH, dilationW,
|
||||
THCudaTensor_data(state, gradInput));
|
||||
THCudaCheck(cudaGetLastError());
|
||||
|
||||
THCudaTensor_free(state, gradOutput);
|
||||
|
||||
// clean
|
||||
THCudaTensor_free(state, input);
|
||||
THCudaTensor_free(state, gradOutput);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -27,6 +27,21 @@ TH_API void THNN_CudaAbsCriterion_updateGradInput(
|
|||
THCudaTensor *gradInput,
|
||||
bool sizeAverage);
|
||||
|
||||
TH_API void THNN_CudaBCECriterion_updateOutput(
|
||||
THCState *state,
|
||||
THCudaTensor *input,
|
||||
THCudaTensor *target,
|
||||
THCudaTensor *output,
|
||||
bool sizeAverage,
|
||||
THCudaTensor *weights);
|
||||
TH_API void THNN_CudaBCECriterion_updateGradInput(
|
||||
THCState *state,
|
||||
THCudaTensor *input,
|
||||
THCudaTensor *target,
|
||||
THCudaTensor *gradInput,
|
||||
bool sizeAverage,
|
||||
THCudaTensor *weights);
|
||||
|
||||
TH_API void THNN_CudaClassNLLCriterion_updateOutput(
|
||||
THCState *state,
|
||||
THCudaTensor *input,
|
||||
|
|
@ -735,9 +750,29 @@ TH_API void THNN_CudaSpatialMaxPooling_updateOutput(
|
|||
int kW, int kH,
|
||||
int dW, int dH,
|
||||
int padW, int padH,
|
||||
int dilationW, int dilationH,
|
||||
bool ceil_mode);
|
||||
TH_API void THNN_CudaSpatialMaxPooling_updateGradInput(
|
||||
THCState *state,
|
||||
THCudaTensor *input,
|
||||
THCudaTensor *gradOutput,
|
||||
THCudaTensor *gradInput,
|
||||
THCudaTensor *indices,
|
||||
int kW, int kH,
|
||||
int dW, int dH,
|
||||
int padW, int padH,
|
||||
bool ceil_mode);
|
||||
|
||||
TH_API void THNN_CudaSpatialDilatedMaxPooling_updateOutput(
|
||||
THCState *state,
|
||||
THCudaTensor *input,
|
||||
THCudaTensor *output,
|
||||
THCudaTensor *indices,
|
||||
int kW, int kH,
|
||||
int dW, int dH,
|
||||
int padW, int padH,
|
||||
int dilationW, int dilationH,
|
||||
bool ceil_mode);
|
||||
TH_API void THNN_CudaSpatialDilatedMaxPooling_updateGradInput(
|
||||
THCState *state,
|
||||
THCudaTensor *input,
|
||||
THCudaTensor *gradOutput,
|
||||
|
|
@ -964,6 +999,26 @@ TH_API void THNN_CudaVolumetricMaxPooling_updateGradInput(
|
|||
int dT, int dW, int dH,
|
||||
int padT, int padW, int padH);
|
||||
|
||||
TH_API void THNN_CudaVolumetricDilatedMaxPooling_updateOutput(
|
||||
THCState *state,
|
||||
THCudaTensor *input,
|
||||
THCudaTensor *output,
|
||||
THCudaTensor *indices,
|
||||
int kT, int kW, int kH,
|
||||
int dT, int dW, int dH,
|
||||
int padT, int padW, int padH,
|
||||
int dilationT, int dilationW, int dilationH,
|
||||
bool ceilMode);
|
||||
TH_API void THNN_CudaVolumetricDilatedMaxPooling_updateGradInput(
|
||||
THCState *state,
|
||||
THCudaTensor *input,
|
||||
THCudaTensor *gradOutput,
|
||||
THCudaTensor *gradInput,
|
||||
THCudaTensor *indices,
|
||||
int dT, int dW, int dH,
|
||||
int padT, int padW, int padH,
|
||||
int dilationT, int dilationW, int dilationH);
|
||||
|
||||
TH_API void THNN_CudaVolumetricMaxUnpooling_updateOutput(
|
||||
THCState *state,
|
||||
THCudaTensor *input,
|
||||
|
|
|
|||
437
torch/lib/THCUNN/VolumetricDilatedMaxPooling.cu
Normal file
437
torch/lib/THCUNN/VolumetricDilatedMaxPooling.cu
Normal file
|
|
@ -0,0 +1,437 @@
|
|||
#include "THCUNN.h"
|
||||
#include "common.h"
|
||||
#include "THCDeviceTensor.cuh"
|
||||
#include "THCDeviceTensorUtils.cuh"
|
||||
#include "THCDeviceUtils.cuh"
|
||||
|
||||
#include <cfloat>
|
||||
|
||||
__global__ void cuda_VolumetricDilatedMaxPooling_updateOutput(
|
||||
THCDeviceTensor<float, 4> input,
|
||||
THCDeviceTensor<float, 4> indices,
|
||||
THCDeviceTensor<float, 4> output,
|
||||
int kT, int kH, int kW,
|
||||
int dT, int dH, int dW,
|
||||
int padT, int padH, int padW,
|
||||
int dilationT, int dilationH, int dilationW,
|
||||
int offsetZ)
|
||||
{
|
||||
int oColumn = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
int oRow = blockIdx.y * blockDim.y + threadIdx.y;
|
||||
int oFrame = (blockIdx.z + offsetZ) % output.getSize(1); // output frame/time
|
||||
int slice = (blockIdx.z + offsetZ) / output.getSize(1); // output slice/feature
|
||||
|
||||
if (oRow < output.getSize(2) && oColumn < output.getSize(3))
|
||||
{
|
||||
int iColumn = oColumn * dW - padW;
|
||||
int iRow = oRow * dH - padH;
|
||||
int iFrame = oFrame * dT - padT;
|
||||
|
||||
int maxColumn = 0;
|
||||
int maxRow = 0;
|
||||
int maxFrame = 0;
|
||||
|
||||
float max = -FLT_MAX;
|
||||
|
||||
for (int frame = 0; frame < kT; ++frame)
|
||||
{
|
||||
if (iFrame + frame * dilationT < input.getSize(1) && iFrame + frame * dilationT >= 0)
|
||||
{
|
||||
for (int row = 0; row < kH; ++row)
|
||||
{
|
||||
if (iRow + row * dilationH < input.getSize(2) && iRow + row * dilationH >= 0)
|
||||
{
|
||||
for (int column = 0; column < kW; ++column)
|
||||
{
|
||||
if (iColumn + column * dilationW < input.getSize(3) && iColumn + column * dilationW >= 0)
|
||||
{
|
||||
float val = input[slice][iFrame + frame * dilationT][iRow + row * dilationH][iColumn + column * dilationW];
|
||||
|
||||
if (max < val)
|
||||
{
|
||||
max = val;
|
||||
maxColumn = column;
|
||||
maxRow = row;
|
||||
maxFrame = frame;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
output[slice][oFrame][oRow][oColumn] = max;
|
||||
float *idx = &indices[slice][oFrame][oRow][oColumn];
|
||||
((unsigned char*)(idx))[0] = maxFrame;
|
||||
((unsigned char*)(idx))[1] = maxRow;
|
||||
((unsigned char*)(idx))[2] = maxColumn;
|
||||
((unsigned char*)(idx))[3] = 0;
|
||||
}
|
||||
}
|
||||
|
||||
template <int KERNEL_WIDTH>
|
||||
__global__ void cuda_VolumetricDilatedMaxPooling_updateOutput(
|
||||
THCDeviceTensor<float, 4> input, THCDeviceTensor<float, 4> indices,
|
||||
THCDeviceTensor<float, 4> output,
|
||||
int kT, int kH,
|
||||
int dT, int dH, int dW,
|
||||
int padT, int padH, int padW,
|
||||
int dilationT, int dilationH, int dilationW,
|
||||
int offsetZ)
|
||||
{
|
||||
int oColumn = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
int oRow = blockIdx.y * blockDim.y + threadIdx.y;
|
||||
int oFrame = (blockIdx.z + offsetZ) % output.getSize(1); // output frame/time
|
||||
int slice = (blockIdx.z + offsetZ) / output.getSize(1); // output slice/feature
|
||||
|
||||
if (oRow < output.getSize(2) && oColumn < output.getSize(3))
|
||||
{
|
||||
int iColumn = oColumn * dW - padW;
|
||||
int iRow = oRow * dH - padH;
|
||||
int iFrame = oFrame * dT - padT;
|
||||
|
||||
int maxColumn = 0;
|
||||
int maxRow = 0;
|
||||
int maxFrame;
|
||||
|
||||
float max = -FLT_MAX;
|
||||
|
||||
for (int frame = 0; frame < kT; ++frame)
|
||||
{
|
||||
if (iFrame + frame * dilationT < input.getSize(1) && iFrame + frame * dilationT >= 0)
|
||||
{
|
||||
for (int row = 0; row < kH; ++row)
|
||||
{
|
||||
if (iRow + row * dilationH < input.getSize(2) && iRow + row * dilationH >= 0)
|
||||
{
|
||||
for (int column = 0; column < KERNEL_WIDTH; ++column)
|
||||
{
|
||||
if (iColumn + column * dilationW < input.getSize(3) && iColumn + column * dilationW >= 0)
|
||||
{
|
||||
float val = input[slice][iFrame + frame * dilationT][iRow + row * dilationH][iColumn + column * dilationW];
|
||||
|
||||
if (max < val)
|
||||
{
|
||||
max = val;
|
||||
maxColumn = column;
|
||||
maxRow = row;
|
||||
maxFrame = frame;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
output[slice][oFrame][oRow][oColumn] = max;
|
||||
float *idx = &indices[slice][oFrame][oRow][oColumn];
|
||||
((unsigned char*)(idx))[0] = maxFrame;
|
||||
((unsigned char*)(idx))[1] = maxRow;
|
||||
((unsigned char*)(idx))[2] = maxColumn;
|
||||
((unsigned char*)(idx))[3] = 0;
|
||||
}
|
||||
}
|
||||
|
||||
#define UPDATE_OUTPUT_KERNEL_WIDTH(KW) case KW: \
|
||||
cuda_VolumetricDilatedMaxPooling_updateOutput<KW><<<grid, block, \
|
||||
0, THCState_getCurrentStream(state)>>>( \
|
||||
cudaInput, cudaIndices, cudaOutput, kT, kH, dT, dH, dW, padT, padH, padW,\
|
||||
dilationT, dilationH, dilationW, offsetZ); \
|
||||
break
|
||||
|
||||
|
||||
void THNN_CudaVolumetricDilatedMaxPooling_updateOutput(
|
||||
THCState *state, THCudaTensor *input, THCudaTensor *output, THCudaTensor *indices,
|
||||
int kT, int kW, int kH,
|
||||
int dT, int dW, int dH,
|
||||
int padT, int padW, int padH,
|
||||
int dilationT, int dilationW, int dilationH,
|
||||
bool ceilMode)
|
||||
{
|
||||
int batchSize;
|
||||
int inputSlices;
|
||||
int inputTime;
|
||||
int inputHeight;
|
||||
int inputWidth;
|
||||
int outputTime;
|
||||
int outputHeight;
|
||||
int outputWidth;
|
||||
|
||||
THCUNN_assertSameGPU(state, 3, input, indices, output);
|
||||
|
||||
if (THCudaTensor_nDimension(state, input) == 4)
|
||||
{
|
||||
THArgCheck(
|
||||
THCudaTensor_size(state, input, 1) >= kT &&
|
||||
THCudaTensor_size(state, input, 2) >= kH &&
|
||||
THCudaTensor_size(state, input, 3) >= kW, 2,
|
||||
"input image smaller than kernel size"
|
||||
);
|
||||
|
||||
/* sizes */
|
||||
batchSize = 1;
|
||||
inputSlices = THCudaTensor_size(state, input, 0);
|
||||
inputTime = THCudaTensor_size(state, input, 1);
|
||||
inputHeight = THCudaTensor_size(state, input, 2);
|
||||
inputWidth = THCudaTensor_size(state, input, 3);
|
||||
}
|
||||
else if (THCudaTensor_nDimension(state, input) == 5)
|
||||
{
|
||||
THArgCheck(
|
||||
THCudaTensor_size(state, input, 4) >= kW &&
|
||||
THCudaTensor_size(state, input, 3) >= kH &&
|
||||
THCudaTensor_size(state, input, 2) >= kT, 2,
|
||||
"input image smaller than kernel size"
|
||||
);
|
||||
|
||||
/* sizes */
|
||||
batchSize = THCudaTensor_size(state, input, 0);
|
||||
inputSlices = THCudaTensor_size(state, input, 1);
|
||||
inputTime = THCudaTensor_size(state, input, 2);
|
||||
inputHeight = THCudaTensor_size(state, input, 3);
|
||||
inputWidth = THCudaTensor_size(state, input, 4);
|
||||
}
|
||||
else
|
||||
{
|
||||
THArgCheck(false, 2, "4D or 5D tensor expected");
|
||||
}
|
||||
|
||||
THArgCheck(kT/2 >= padT && kW/2 >= padW && kH/2 >= padH, 2,
|
||||
"pad should be smaller than half of kernel size"
|
||||
);
|
||||
|
||||
if (ceilMode)
|
||||
{
|
||||
outputTime = (int)(ceil((float)(inputTime - (dilationT * (kT - 1) + 1) + 2*padT) / dT)) + 1;
|
||||
outputHeight = (int)(ceil((float)(inputHeight - (dilationH * (kH - 1) + 1) + 2*padH) / dH)) + 1;
|
||||
outputWidth = (int)(ceil((float)(inputWidth - (dilationW * (kW - 1) + 1) + 2*padW) / dW)) + 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
outputTime = (int)(floor((float)(inputTime - (dilationT * (kT - 1) + 1) + 2*padT) / dT)) + 1;
|
||||
outputHeight = (int)(floor((float)(inputHeight - (dilationH * (kH - 1) + 1) + 2*padH) / dH)) + 1;
|
||||
outputWidth = (int)(floor((float)(inputWidth - (dilationW * (kW - 1) + 1) + 2*padW) / dW)) + 1;
|
||||
}
|
||||
|
||||
if (outputTime < 1 || outputHeight < 1 || outputWidth < 1)
|
||||
THError("Given input size: (%dx%dx%dx%d). Calculated output size: (%dx%dx%dx%d). Output size is too small",
|
||||
inputSlices,inputTime,inputHeight,inputWidth,inputSlices,outputTime,outputHeight,outputWidth);
|
||||
|
||||
if (padT || padW || padH)
|
||||
{
|
||||
if ((outputTime - 1)*dT >= inputTime + padT)
|
||||
--outputTime;
|
||||
if ((outputHeight - 1)*dH >= inputHeight + padH)
|
||||
--outputHeight;
|
||||
if ((outputWidth - 1)*dW >= inputWidth + padW)
|
||||
--outputWidth;
|
||||
}
|
||||
|
||||
if (input->nDimension == 4) /* 4D */
|
||||
{
|
||||
/* resize output */
|
||||
THCudaTensor_resize4d(state, output, inputSlices,
|
||||
outputTime, outputHeight, outputWidth);
|
||||
/* indices pack ti,i,j locations for each output point as uchar into
|
||||
each float of the tensor */
|
||||
THCudaTensor_resize4d(state, indices, inputSlices,
|
||||
outputTime, outputHeight, outputWidth);
|
||||
}
|
||||
else
|
||||
{ /* 5D */
|
||||
THCudaTensor_resize5d(state, output, batchSize, inputSlices,
|
||||
outputTime, outputHeight, outputWidth);
|
||||
// Index tensor packs index offsets as uchars into floats
|
||||
THCudaTensor_resize5d(state, indices, batchSize, inputSlices,
|
||||
outputTime, outputHeight, outputWidth);
|
||||
}
|
||||
|
||||
input = THCudaTensor_newContiguous(state, input);
|
||||
|
||||
// Collapse batch and feature dimensions
|
||||
THCDeviceTensor<float, 4> cudaInput;
|
||||
THCDeviceTensor<float, 4> cudaOutput;
|
||||
if (THCudaTensor_nDimension(state, input) == 4)
|
||||
{
|
||||
cudaInput = toDeviceTensor<float, 4>(state, input);
|
||||
cudaOutput = toDeviceTensor<float, 4>(state, output);
|
||||
}
|
||||
else
|
||||
{
|
||||
cudaInput = toDeviceTensor<float, 5>(state, input).downcastOuter<4>();
|
||||
cudaOutput = toDeviceTensor<float, 5>(state, output).downcastOuter<4>();
|
||||
}
|
||||
|
||||
THLongStorage *indicesSize = THLongStorage_newWithSize(4);
|
||||
long indicesSizeRaw[4] = { batchSize * inputSlices,
|
||||
outputTime, outputHeight, outputWidth };
|
||||
THLongStorage_rawCopy(indicesSize, indicesSizeRaw);
|
||||
|
||||
THCudaTensor *indices1 = THCudaTensor_newWithStorage(
|
||||
state, THCudaTensor_storage(state, indices),
|
||||
THCudaTensor_storageOffset(state, indices),
|
||||
indicesSize, NULL);
|
||||
|
||||
THLongStorage_free(indicesSize);
|
||||
|
||||
THCDeviceTensor<float, 4> cudaIndices =
|
||||
toDeviceTensor<float, 4>(state, indices1);
|
||||
|
||||
int totalZ = outputTime * inputSlices * batchSize;
|
||||
int offsetZ = 0;
|
||||
dim3 block(32, 8);
|
||||
|
||||
while (totalZ > 0) {
|
||||
dim3 grid(THCCeilDiv(outputWidth, static_cast<int>(block.x)),
|
||||
THCCeilDiv(outputHeight, static_cast<int>(block.y)),
|
||||
totalZ > 65535 ? 65535 : totalZ);
|
||||
|
||||
switch (kW)
|
||||
{
|
||||
UPDATE_OUTPUT_KERNEL_WIDTH(1);
|
||||
UPDATE_OUTPUT_KERNEL_WIDTH(2);
|
||||
UPDATE_OUTPUT_KERNEL_WIDTH(3);
|
||||
UPDATE_OUTPUT_KERNEL_WIDTH(4);
|
||||
UPDATE_OUTPUT_KERNEL_WIDTH(5);
|
||||
UPDATE_OUTPUT_KERNEL_WIDTH(6);
|
||||
UPDATE_OUTPUT_KERNEL_WIDTH(7);
|
||||
default:
|
||||
cuda_VolumetricDilatedMaxPooling_updateOutput<<<grid, block,
|
||||
0, THCState_getCurrentStream(state)>>>(
|
||||
cudaInput, cudaIndices, cudaOutput,
|
||||
kT, kH, kW, dT, dH, dW,
|
||||
padT, padH, padW, dilationT, dilationH, dilationW, offsetZ);
|
||||
}
|
||||
THCudaCheck(cudaGetLastError());
|
||||
totalZ -= 65535;
|
||||
offsetZ += 65535;
|
||||
}
|
||||
|
||||
THCudaTensor_free(state, input);
|
||||
THCudaTensor_free(state, indices1);
|
||||
}
|
||||
|
||||
#undef UPDATE_OUTPUT_KERNEL_WIDTH
|
||||
|
||||
__global__ void cuda_VolumetricDilatedMaxPooling_updateGradInput(
|
||||
THCDeviceTensor<float, 4> gradOutput,
|
||||
THCDeviceTensor<float, 4> indices,
|
||||
THCDeviceTensor<float, 4> gradInput,
|
||||
int dT, int dH, int dW,
|
||||
int padT, int padH, int padW,
|
||||
int dilationT, int dilationH, int dilationW,
|
||||
int offsetZ)
|
||||
{
|
||||
int oColumn = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
int oRow = blockIdx.y * blockDim.y + threadIdx.y;
|
||||
int oFrame = (blockIdx.z + offsetZ) % gradOutput.getSize(1); // output frame/time
|
||||
int slice = (blockIdx.z + offsetZ) / gradOutput.getSize(1); // output slice/feature
|
||||
|
||||
if (oRow < gradOutput.getSize(2) && oColumn < gradOutput.getSize(3))
|
||||
{
|
||||
float *idx = &indices[slice][oFrame][oRow][oColumn];
|
||||
int iFrame = ((unsigned char*)(idx))[0] * dilationT + oFrame * dT - padT;
|
||||
int iRow = ((unsigned char*)(idx))[1] * dilationH + oRow * dH - padH;
|
||||
int iColumn = ((unsigned char*)(idx))[2] * dilationW + oColumn * dW - padW;
|
||||
atomicAdd(&gradInput[slice][iFrame][iRow][iColumn],
|
||||
gradOutput[slice][oFrame][oRow][oColumn]);
|
||||
}
|
||||
}
|
||||
|
||||
void THNN_CudaVolumetricDilatedMaxPooling_updateGradInput(
|
||||
THCState *state, THCudaTensor *input, THCudaTensor *gradOutput, THCudaTensor *gradInput,
|
||||
THCudaTensor *indices,
|
||||
int dT, int dW, int dH,
|
||||
int padT, int padW, int padH,
|
||||
int dilationT, int dilationW, int dilationH)
|
||||
{
|
||||
// Resize and initialize result tensor.
|
||||
THCudaTensor_resizeAs(state, gradInput, input);
|
||||
THCudaTensor_zero(state, gradInput);
|
||||
|
||||
int batchSize;
|
||||
int inputSlices;
|
||||
|
||||
int outputTime;
|
||||
int outputHeight;
|
||||
int outputWidth;
|
||||
|
||||
THCUNN_assertSameGPU(state, 4, input, indices, gradOutput, gradInput);
|
||||
|
||||
if (THCudaTensor_nDimension(state, input) == 4) /* 4D */
|
||||
{
|
||||
batchSize = 1;
|
||||
inputSlices = THCudaTensor_size(state, input, 0);
|
||||
|
||||
outputTime = THCudaTensor_size(state, gradOutput, 1);
|
||||
outputHeight = THCudaTensor_size(state, gradOutput, 2);
|
||||
outputWidth = THCudaTensor_size(state, gradOutput, 3);
|
||||
}
|
||||
else
|
||||
{
|
||||
batchSize = THCudaTensor_size(state, input, 0);
|
||||
inputSlices = THCudaTensor_size(state, input, 1);
|
||||
|
||||
outputTime = THCudaTensor_size(state, gradOutput, 2);
|
||||
outputHeight = THCudaTensor_size(state, gradOutput, 3);
|
||||
outputWidth = THCudaTensor_size(state, gradOutput, 4);
|
||||
}
|
||||
|
||||
gradOutput = THCudaTensor_newContiguous(state, gradOutput);
|
||||
|
||||
// Collapse batch and feature dimensions
|
||||
THCDeviceTensor<float, 4> cudaGradInput;
|
||||
THCDeviceTensor<float, 4> cudaGradOutput;
|
||||
if (THCudaTensor_nDimension(state, input) == 4)
|
||||
{
|
||||
cudaGradInput = toDeviceTensor<float, 4>(state, gradInput);
|
||||
cudaGradOutput = toDeviceTensor<float, 4>(state, gradOutput);
|
||||
}
|
||||
else
|
||||
{
|
||||
cudaGradInput =
|
||||
toDeviceTensor<float, 5>(state, gradInput).downcastOuter<4>();
|
||||
cudaGradOutput =
|
||||
toDeviceTensor<float, 5>(state, gradOutput).downcastOuter<4>();
|
||||
}
|
||||
|
||||
THLongStorage *indicesSize = THLongStorage_newWithSize(4);
|
||||
long indicesSizeRaw[4] = { batchSize * inputSlices,
|
||||
outputTime, outputHeight, outputWidth };
|
||||
THLongStorage_rawCopy(indicesSize, indicesSizeRaw);
|
||||
THCudaTensor *indices1 = THCudaTensor_newWithStorage(
|
||||
state, THCudaTensor_storage(state, indices),
|
||||
THCudaTensor_storageOffset(state, indices), indicesSize, NULL);
|
||||
THLongStorage_free(indicesSize);
|
||||
|
||||
THCDeviceTensor<float, 4> cudaIndices =
|
||||
toDeviceTensor<float, 4>(state, indices1);
|
||||
|
||||
int totalZ = outputTime * inputSlices * batchSize;
|
||||
int offsetZ = 0;
|
||||
dim3 block(32, 8);
|
||||
|
||||
while (totalZ > 0) {
|
||||
dim3 grid(THCCeilDiv(outputWidth, static_cast<int>(block.x)),
|
||||
THCCeilDiv(outputHeight, static_cast<int>(block.y)),
|
||||
totalZ > 65535 ? 65535 : totalZ);
|
||||
|
||||
cuda_VolumetricDilatedMaxPooling_updateGradInput<<<grid, block,
|
||||
0, THCState_getCurrentStream(state)>>>(
|
||||
cudaGradOutput,
|
||||
cudaIndices,
|
||||
cudaGradInput,
|
||||
dT, dH, dW,
|
||||
padT, padH, padW,
|
||||
dilationT, dilationH, dilationW, offsetZ);
|
||||
THCudaCheck(cudaGetLastError());
|
||||
totalZ -= 65535;
|
||||
offsetZ += 65535;
|
||||
}
|
||||
|
||||
// cleanup
|
||||
THCudaTensor_free(state, gradOutput);
|
||||
THCudaTensor_free(state, indices1);
|
||||
}
|
||||
|
|
@ -6,137 +6,6 @@
|
|||
|
||||
#include <cfloat>
|
||||
|
||||
__global__ void cuda_VolumetricMaxPooling_updateOutput(
|
||||
THCDeviceTensor<float, 4> input,
|
||||
THCDeviceTensor<float, 4> indices,
|
||||
THCDeviceTensor<float, 4> output,
|
||||
int kT, int kH, int kW,
|
||||
int dT, int dH, int dW,
|
||||
int padT, int padH, int padW, int offsetZ)
|
||||
{
|
||||
int oColumn = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
int oRow = blockIdx.y * blockDim.y + threadIdx.y;
|
||||
int oFrame = (blockIdx.z + offsetZ) % output.getSize(1); // output frame/time
|
||||
int slice = (blockIdx.z + offsetZ) / output.getSize(1); // output slice/feature
|
||||
|
||||
if (oRow < output.getSize(2) && oColumn < output.getSize(3))
|
||||
{
|
||||
int iColumn = oColumn * dW - padW;
|
||||
int iRow = oRow * dH - padH;
|
||||
int iFrame = oFrame * dT - padT;
|
||||
|
||||
int maxColumn = 0;
|
||||
int maxRow = 0;
|
||||
int maxFrame = 0;
|
||||
|
||||
float max = -FLT_MAX;
|
||||
|
||||
for (int frame = 0; frame < kT; ++frame)
|
||||
{
|
||||
if (iFrame + frame < input.getSize(1) && iFrame + frame >= 0)
|
||||
{
|
||||
for (int row = 0; row < kH; ++row)
|
||||
{
|
||||
if (iRow + row < input.getSize(2) && iRow + row >= 0)
|
||||
{
|
||||
for (int column = 0; column < kW; ++column)
|
||||
{
|
||||
if (iColumn + column < input.getSize(3) && iColumn + column >= 0)
|
||||
{
|
||||
float val = input[slice][iFrame + frame][iRow + row][iColumn + column];
|
||||
|
||||
if (max < val)
|
||||
{
|
||||
max = val;
|
||||
maxColumn = column;
|
||||
maxRow = row;
|
||||
maxFrame = frame;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
output[slice][oFrame][oRow][oColumn] = max;
|
||||
float *idx = &indices[slice][oFrame][oRow][oColumn];
|
||||
((unsigned char*)(idx))[0] = maxFrame;
|
||||
((unsigned char*)(idx))[1] = maxRow;
|
||||
((unsigned char*)(idx))[2] = maxColumn;
|
||||
((unsigned char*)(idx))[3] = 0;
|
||||
}
|
||||
}
|
||||
|
||||
template <int KERNEL_WIDTH>
|
||||
__global__ void cuda_VolumetricMaxPooling_updateOutput(
|
||||
THCDeviceTensor<float, 4> input, THCDeviceTensor<float, 4> indices,
|
||||
THCDeviceTensor<float, 4> output,
|
||||
int kT, int kH,
|
||||
int dT, int dH, int dW,
|
||||
int padT, int padH, int padW, int offsetZ)
|
||||
{
|
||||
int oColumn = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
int oRow = blockIdx.y * blockDim.y + threadIdx.y;
|
||||
int oFrame = (blockIdx.z + offsetZ) % output.getSize(1); // output frame/time
|
||||
int slice = (blockIdx.z + offsetZ) / output.getSize(1); // output slice/feature
|
||||
|
||||
if (oRow < output.getSize(2) && oColumn < output.getSize(3))
|
||||
{
|
||||
int iColumn = oColumn * dW - padW;
|
||||
int iRow = oRow * dH - padH;
|
||||
int iFrame = oFrame * dT - padT;
|
||||
|
||||
int maxColumn = 0;
|
||||
int maxRow = 0;
|
||||
int maxFrame;
|
||||
|
||||
float max = -FLT_MAX;
|
||||
|
||||
for (int frame = 0; frame < kT; ++frame)
|
||||
{
|
||||
if (iFrame + frame < input.getSize(1) && iFrame + frame >= 0)
|
||||
{
|
||||
for (int row = 0; row < kH; ++row)
|
||||
{
|
||||
if (iRow + row < input.getSize(2) && iRow + row >= 0)
|
||||
{
|
||||
for (int column = 0; column < KERNEL_WIDTH; ++column)
|
||||
{
|
||||
if (iColumn + column < input.getSize(3) && iColumn + column >= 0)
|
||||
{
|
||||
float val = input[slice][iFrame + frame][iRow + row][iColumn + column];
|
||||
|
||||
if (max < val)
|
||||
{
|
||||
max = val;
|
||||
maxColumn = column;
|
||||
maxRow = row;
|
||||
maxFrame = frame;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
output[slice][oFrame][oRow][oColumn] = max;
|
||||
float *idx = &indices[slice][oFrame][oRow][oColumn];
|
||||
((unsigned char*)(idx))[0] = maxFrame;
|
||||
((unsigned char*)(idx))[1] = maxRow;
|
||||
((unsigned char*)(idx))[2] = maxColumn;
|
||||
((unsigned char*)(idx))[3] = 0;
|
||||
}
|
||||
}
|
||||
|
||||
#define UPDATE_OUTPUT_KERNEL_WIDTH(KW) case KW: \
|
||||
cuda_VolumetricMaxPooling_updateOutput<KW><<<grid, block, \
|
||||
0, THCState_getCurrentStream(state)>>>( \
|
||||
cudaInput, cudaIndices, cudaOutput, kT, kH, dT, dH, dW, padT, padH, padW, offsetZ); \
|
||||
break
|
||||
|
||||
|
||||
void THNN_CudaVolumetricMaxPooling_updateOutput(
|
||||
THCState *state, THCudaTensor *input, THCudaTensor *output, THCudaTensor *indices,
|
||||
int kT, int kW, int kH,
|
||||
|
|
@ -144,188 +13,10 @@ void THNN_CudaVolumetricMaxPooling_updateOutput(
|
|||
int padT, int padW, int padH,
|
||||
bool ceilMode)
|
||||
{
|
||||
int batchSize;
|
||||
int inputSlices;
|
||||
int inputTime;
|
||||
int inputHeight;
|
||||
int inputWidth;
|
||||
int outputTime;
|
||||
int outputHeight;
|
||||
int outputWidth;
|
||||
THNN_CudaVolumetricDilatedMaxPooling_updateOutput(
|
||||
state, input, output, indices,
|
||||
kT, kW, kH, dT, dW, dH, padT, padW, padH, 1, 1, 1, ceilMode);
|
||||
|
||||
THCUNN_assertSameGPU(state, 3, input, indices, output);
|
||||
|
||||
if (THCudaTensor_nDimension(state, input) == 4)
|
||||
{
|
||||
THArgCheck(
|
||||
THCudaTensor_size(state, input, 1) >= kT &&
|
||||
THCudaTensor_size(state, input, 2) >= kH &&
|
||||
THCudaTensor_size(state, input, 3) >= kW, 2,
|
||||
"input image smaller than kernel size"
|
||||
);
|
||||
|
||||
/* sizes */
|
||||
batchSize = 1;
|
||||
inputSlices = THCudaTensor_size(state, input, 0);
|
||||
inputTime = THCudaTensor_size(state, input, 1);
|
||||
inputHeight = THCudaTensor_size(state, input, 2);
|
||||
inputWidth = THCudaTensor_size(state, input, 3);
|
||||
}
|
||||
else if (THCudaTensor_nDimension(state, input) == 5)
|
||||
{
|
||||
THArgCheck(
|
||||
THCudaTensor_size(state, input, 4) >= kW &&
|
||||
THCudaTensor_size(state, input, 3) >= kH &&
|
||||
THCudaTensor_size(state, input, 2) >= kT, 2,
|
||||
"input image smaller than kernel size"
|
||||
);
|
||||
|
||||
/* sizes */
|
||||
batchSize = THCudaTensor_size(state, input, 0);
|
||||
inputSlices = THCudaTensor_size(state, input, 1);
|
||||
inputTime = THCudaTensor_size(state, input, 2);
|
||||
inputHeight = THCudaTensor_size(state, input, 3);
|
||||
inputWidth = THCudaTensor_size(state, input, 4);
|
||||
}
|
||||
else
|
||||
{
|
||||
THArgCheck(false, 2, "4D or 5D tensor expected");
|
||||
}
|
||||
|
||||
THArgCheck(kT/2 >= padT && kW/2 >= padW && kH/2 >= padH, 2,
|
||||
"pad should be smaller than half of kernel size"
|
||||
);
|
||||
|
||||
if (ceilMode)
|
||||
{
|
||||
outputTime = (int)(ceil((float)(inputTime - kT + 2*padT) / dT)) + 1;
|
||||
outputHeight = (int)(ceil((float)(inputHeight - kH + 2*padH) / dH)) + 1;
|
||||
outputWidth = (int)(ceil((float)(inputWidth - kW + 2*padW) / dW)) + 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
outputTime = (int)(floor((float)(inputTime - kT + 2*padT) / dT)) + 1;
|
||||
outputHeight = (int)(floor((float)(inputHeight - kH + 2*padH) / dH)) + 1;
|
||||
outputWidth = (int)(floor((float)(inputWidth - kW + 2*padW) / dW)) + 1;
|
||||
}
|
||||
|
||||
if (padT || padW || padH)
|
||||
{
|
||||
if ((outputTime - 1)*dT >= inputTime + padT)
|
||||
--outputTime;
|
||||
if ((outputHeight - 1)*dH >= inputHeight + padH)
|
||||
--outputHeight;
|
||||
if ((outputWidth - 1)*dW >= inputWidth + padW)
|
||||
--outputWidth;
|
||||
}
|
||||
|
||||
if (input->nDimension == 4) /* 4D */
|
||||
{
|
||||
/* resize output */
|
||||
THCudaTensor_resize4d(state, output, inputSlices,
|
||||
outputTime, outputHeight, outputWidth);
|
||||
/* indices pack ti,i,j locations for each output point as uchar into
|
||||
each float of the tensor */
|
||||
THCudaTensor_resize4d(state, indices, inputSlices,
|
||||
outputTime, outputHeight, outputWidth);
|
||||
}
|
||||
else
|
||||
{ /* 5D */
|
||||
THCudaTensor_resize5d(state, output, batchSize, inputSlices,
|
||||
outputTime, outputHeight, outputWidth);
|
||||
// Index tensor packs index offsets as uchars into floats
|
||||
THCudaTensor_resize5d(state, indices, batchSize, inputSlices,
|
||||
outputTime, outputHeight, outputWidth);
|
||||
}
|
||||
|
||||
input = THCudaTensor_newContiguous(state, input);
|
||||
|
||||
// Collapse batch and feature dimensions
|
||||
THCDeviceTensor<float, 4> cudaInput;
|
||||
THCDeviceTensor<float, 4> cudaOutput;
|
||||
if (THCudaTensor_nDimension(state, input) == 4)
|
||||
{
|
||||
cudaInput = toDeviceTensor<float, 4>(state, input);
|
||||
cudaOutput = toDeviceTensor<float, 4>(state, output);
|
||||
}
|
||||
else
|
||||
{
|
||||
cudaInput = toDeviceTensor<float, 5>(state, input).downcastOuter<4>();
|
||||
cudaOutput = toDeviceTensor<float, 5>(state, output).downcastOuter<4>();
|
||||
}
|
||||
|
||||
THLongStorage *indicesSize = THLongStorage_newWithSize(4);
|
||||
long indicesSizeRaw[4] = { batchSize * inputSlices,
|
||||
outputTime, outputHeight, outputWidth };
|
||||
THLongStorage_rawCopy(indicesSize, indicesSizeRaw);
|
||||
|
||||
THCudaTensor *indices1 = THCudaTensor_newWithStorage(
|
||||
state, THCudaTensor_storage(state, indices),
|
||||
THCudaTensor_storageOffset(state, indices),
|
||||
indicesSize, NULL);
|
||||
|
||||
THLongStorage_free(indicesSize);
|
||||
|
||||
THCDeviceTensor<float, 4> cudaIndices =
|
||||
toDeviceTensor<float, 4>(state, indices1);
|
||||
|
||||
int totalZ = outputTime * inputSlices * batchSize;
|
||||
int offsetZ = 0;
|
||||
dim3 block(32, 8);
|
||||
|
||||
while (totalZ > 0) {
|
||||
dim3 grid(THCCeilDiv(outputWidth, static_cast<int>(block.x)),
|
||||
THCCeilDiv(outputHeight, static_cast<int>(block.y)),
|
||||
totalZ > 65535 ? 65535 : totalZ);
|
||||
|
||||
switch (kW)
|
||||
{
|
||||
UPDATE_OUTPUT_KERNEL_WIDTH(1);
|
||||
UPDATE_OUTPUT_KERNEL_WIDTH(2);
|
||||
UPDATE_OUTPUT_KERNEL_WIDTH(3);
|
||||
UPDATE_OUTPUT_KERNEL_WIDTH(4);
|
||||
UPDATE_OUTPUT_KERNEL_WIDTH(5);
|
||||
UPDATE_OUTPUT_KERNEL_WIDTH(6);
|
||||
UPDATE_OUTPUT_KERNEL_WIDTH(7);
|
||||
default:
|
||||
cuda_VolumetricMaxPooling_updateOutput<<<grid, block,
|
||||
0, THCState_getCurrentStream(state)>>>(
|
||||
cudaInput, cudaIndices, cudaOutput,
|
||||
kT, kH, kW, dT, dH, dW,
|
||||
padT, padH, padW, offsetZ);
|
||||
}
|
||||
THCudaCheck(cudaGetLastError());
|
||||
totalZ -= 65535;
|
||||
offsetZ += 65535;
|
||||
}
|
||||
|
||||
THCudaTensor_free(state, input);
|
||||
THCudaTensor_free(state, indices1);
|
||||
}
|
||||
|
||||
#undef UPDATE_OUTPUT_KERNEL_WIDTH
__global__ void cuda_VolumetricMaxPooling_updateGradInput(
  THCDeviceTensor<float, 4> gradOutput,
  THCDeviceTensor<float, 4> indices,
  THCDeviceTensor<float, 4> gradInput,
  int dT, int dH, int dW,
  int padT, int padH, int padW, int offsetZ)
{
  int oColumn = blockIdx.x * blockDim.x + threadIdx.x;
  int oRow = blockIdx.y * blockDim.y + threadIdx.y;
  int oFrame = (blockIdx.z + offsetZ) % gradOutput.getSize(1); // output frame/time
  int slice = (blockIdx.z + offsetZ) / gradOutput.getSize(1); // output slice/feature

  if (oRow < gradOutput.getSize(2) && oColumn < gradOutput.getSize(3))
  {
    float *idx = &indices[slice][oFrame][oRow][oColumn];
    int iFrame = ((unsigned char*)(idx))[0] + oFrame * dT - padT;
    int iRow = ((unsigned char*)(idx))[1] + oRow * dH - padH;
    int iColumn = ((unsigned char*)(idx))[2] + oColumn * dW - padW;
    atomicAdd(&gradInput[slice][iFrame][iRow][iColumn],
              gradOutput[slice][oFrame][oRow][oColumn]);
  }
}

void THNN_CudaVolumetricMaxPooling_updateGradInput(

@ -334,90 +25,8 @@ void THNN_CudaVolumetricMaxPooling_updateGradInput(

  int dT, int dW, int dH,
  int padT, int padW, int padH)
{
  // Resize and initialize result tensor.
  THCudaTensor_resizeAs(state, gradInput, input);
  THCudaTensor_zero(state, gradInput);
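  // The non-dilated version delegates to the dilated implementation; the trailing
  // 1, 1, 1 arguments below are presumably the dilation in each dimension.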
  THNN_CudaVolumetricDilatedMaxPooling_updateGradInput(
    state, input, gradOutput, gradInput, indices,
    dT, dW, dH, padT, padW, padH, 1, 1, 1);

  int batchSize;
  int inputSlices;

  int outputTime;
  int outputHeight;
  int outputWidth;

  THCUNN_assertSameGPU(state, 4, input, indices, gradOutput, gradInput);

  if (THCudaTensor_nDimension(state, input) == 4) /* 4D */
  {
    batchSize = 1;
    inputSlices = THCudaTensor_size(state, input, 0);

    outputTime = THCudaTensor_size(state, gradOutput, 1);
    outputHeight = THCudaTensor_size(state, gradOutput, 2);
    outputWidth = THCudaTensor_size(state, gradOutput, 3);
  }
  else
  {
    batchSize = THCudaTensor_size(state, input, 0);
    inputSlices = THCudaTensor_size(state, input, 1);

    outputTime = THCudaTensor_size(state, gradOutput, 2);
    outputHeight = THCudaTensor_size(state, gradOutput, 3);
    outputWidth = THCudaTensor_size(state, gradOutput, 4);
  }

  gradOutput = THCudaTensor_newContiguous(state, gradOutput);

  // Collapse batch and feature dimensions
  THCDeviceTensor<float, 4> cudaGradInput;
  THCDeviceTensor<float, 4> cudaGradOutput;
  if (THCudaTensor_nDimension(state, input) == 4)
  {
    cudaGradInput = toDeviceTensor<float, 4>(state, gradInput);
    cudaGradOutput = toDeviceTensor<float, 4>(state, gradOutput);
  }
  else
  {
    cudaGradInput =
      toDeviceTensor<float, 5>(state, gradInput).downcastOuter<4>();
    cudaGradOutput =
      toDeviceTensor<float, 5>(state, gradOutput).downcastOuter<4>();
  }

  THLongStorage *indicesSize = THLongStorage_newWithSize(4);
  long indicesSizeRaw[4] = { batchSize * inputSlices,
                             outputTime, outputHeight, outputWidth };
  THLongStorage_rawCopy(indicesSize, indicesSizeRaw);
  THCudaTensor *indices1 = THCudaTensor_newWithStorage(
    state, THCudaTensor_storage(state, indices),
    THCudaTensor_storageOffset(state, indices), indicesSize, NULL);
  THLongStorage_free(indicesSize);

  THCDeviceTensor<float, 4> cudaIndices =
    toDeviceTensor<float, 4>(state, indices1);

  int totalZ = outputTime * inputSlices * batchSize;
  int offsetZ = 0;
  dim3 block(32, 8);

  while (totalZ > 0) {
    dim3 grid(THCCeilDiv(outputWidth, static_cast<int>(block.x)),
              THCCeilDiv(outputHeight, static_cast<int>(block.y)),
              totalZ > 65535 ? 65535 : totalZ);

    cuda_VolumetricMaxPooling_updateGradInput<<<grid, block,
      0, THCState_getCurrentStream(state)>>>(
      cudaGradOutput,
      cudaIndices,
      cudaGradInput,
      dT, dH, dW,
      padT, padH, padW, offsetZ);
    THCudaCheck(cudaGetLastError());
    totalZ -= 65535;
    offsetZ += 65535;
  }

  // cleanup
  THCudaTensor_free(state, gradOutput);
  THCudaTensor_free(state, indices1);
}
196
torch/lib/THCUNN/cmake/select_compute_arch.cmake
Normal file

@ -0,0 +1,196 @@
# Synopsis:
#   CUDA_SELECT_NVCC_ARCH_FLAGS(out_variable [target_CUDA_architectures])
#   -- Selects GPU arch flags for nvcc based on target_CUDA_architectures
#      target_CUDA_architectures : Auto | Common | All | LIST(ARCH_AND_PTX ...)
#       - "Auto" detects local machine GPU compute arch at runtime.
#       - "Common" and "All" cover the common and the full set of supported architectures, respectively
#      ARCH_AND_PTX : NAME | NUM.NUM | NUM.NUM(NUM.NUM) | NUM.NUM+PTX
#      NAME: Fermi Kepler Maxwell Kepler+Tegra Kepler+Tesla Maxwell+Tegra Pascal
#      NUM: Any number. Only those pairs are currently accepted by NVCC though:
#           2.0 2.1 3.0 3.2 3.5 3.7 5.0 5.2 5.3 6.0 6.2
#      Returns LIST of flags to be added to CUDA_NVCC_FLAGS in ${out_variable}
#      Additionally, sets ${out_variable}_readable to the resulting numeric list
#      Example:
#       CUDA_SELECT_NVCC_ARCH_FLAGS(ARCH_FLAGS 3.0 3.5+PTX 5.2(5.0) Maxwell)
#       LIST(APPEND CUDA_NVCC_FLAGS ${ARCH_FLAGS})
#
#      More info on CUDA architectures: https://en.wikipedia.org/wiki/CUDA
#

# This list will be used for CUDA_ARCH_NAME = All option
set(CUDA_KNOWN_GPU_ARCHITECTURES "Fermi" "Kepler" "Maxwell")

# This list will be used for CUDA_ARCH_NAME = Common option (enabled by default)
set(CUDA_COMMON_GPU_ARCHITECTURES "3.0" "3.5" "5.0")

if (CUDA_VERSION VERSION_GREATER "6.5")
  list(APPEND CUDA_KNOWN_GPU_ARCHITECTURES "Kepler+Tegra" "Kepler+Tesla" "Maxwell+Tegra")
  list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "5.2")
endif ()

if (CUDA_VERSION VERSION_GREATER "7.5")
  list(APPEND CUDA_KNOWN_GPU_ARCHITECTURES "Pascal")
  list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "6.0" "6.1" "6.1+PTX")
else()
  list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "5.2+PTX")
endif ()


################################################################################################
# A function for automatic detection of GPUs installed (if autodetection is enabled)
# Usage:
#   CUDA_DETECT_INSTALLED_GPUS(OUT_VARIABLE)
#
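#   Note: the result is a space-separated list of compute capabilities (e.g. "3.5 5.2"
#   on a hypothetical two-GPU machine). It is cached in CUDA_GPU_DETECT_OUTPUT, so the
#   detection program is compiled and run only once per build directory.
#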
function(CUDA_DETECT_INSTALLED_GPUS OUT_VARIABLE)
  if(NOT CUDA_GPU_DETECT_OUTPUT)
    set(cufile ${PROJECT_BINARY_DIR}/detect_cuda_archs.cu)

    file(WRITE ${cufile} ""
      "#include <cstdio>\n"
      "int main()\n"
      "{\n"
      "  int count = 0;\n"
      "  if (cudaSuccess != cudaGetDeviceCount(&count)) return -1;\n"
      "  if (count == 0) return -1;\n"
      "  for (int device = 0; device < count; ++device)\n"
      "  {\n"
      "    cudaDeviceProp prop;\n"
      "    if (cudaSuccess == cudaGetDeviceProperties(&prop, device))\n"
      "      std::printf(\"%d.%d \", prop.major, prop.minor);\n"
      "  }\n"
      "  return 0;\n"
      "}\n")

    execute_process(COMMAND "${CUDA_NVCC_EXECUTABLE}" "--run" "${cufile}"
                    "-ccbin" ${CMAKE_CXX_COMPILER}
                    WORKING_DIRECTORY "${PROJECT_BINARY_DIR}/CMakeFiles/"
                    RESULT_VARIABLE nvcc_res OUTPUT_VARIABLE nvcc_out
                    ERROR_QUIET OUTPUT_STRIP_TRAILING_WHITESPACE)

    if(nvcc_res EQUAL 0)
      string(REPLACE "2.1" "2.1(2.0)" nvcc_out "${nvcc_out}")
      set(CUDA_GPU_DETECT_OUTPUT ${nvcc_out} CACHE INTERNAL "Returned GPU architectures from detect_gpus tool" FORCE)
    endif()
  endif()

  if(NOT CUDA_GPU_DETECT_OUTPUT)
    message(STATUS "Automatic GPU detection failed. Building for common architectures.")
    set(${OUT_VARIABLE} ${CUDA_COMMON_GPU_ARCHITECTURES} PARENT_SCOPE)
  else()
    set(${OUT_VARIABLE} ${CUDA_GPU_DETECT_OUTPUT} PARENT_SCOPE)
  endif()
endfunction()


################################################################################################
# Function for selecting GPU arch flags for nvcc based on CUDA architectures from parameter list
# Usage:
#   CUDA_SELECT_NVCC_ARCH_FLAGS(out_variable [list of CUDA compute archs])
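#
#   For example, CUDA_SELECT_NVCC_ARCH_FLAGS(ARCH_FLAGS 3.5 5.2+PTX) should place
#   roughly the following in ARCH_FLAGS:
#     -gencode arch=compute_35,code=sm_35
#     -gencode arch=compute_52,code=sm_52
#     -gencode arch=compute_52,code=compute_52
#   with ARCH_FLAGS_readable set to "sm_35 sm_52 compute_52".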
function(CUDA_SELECT_NVCC_ARCH_FLAGS out_variable)
  set(CUDA_ARCH_LIST "${ARGN}")

  if("X${CUDA_ARCH_LIST}" STREQUAL "X" )
    set(CUDA_ARCH_LIST "Auto")
  endif()

  set(cuda_arch_bin)
  set(cuda_arch_ptx)

  if("${CUDA_ARCH_LIST}" STREQUAL "All")
    set(CUDA_ARCH_LIST ${CUDA_KNOWN_GPU_ARCHITECTURES})
  elseif("${CUDA_ARCH_LIST}" STREQUAL "Common")
    set(CUDA_ARCH_LIST ${CUDA_COMMON_GPU_ARCHITECTURES})
  elseif("${CUDA_ARCH_LIST}" STREQUAL "Auto")
    CUDA_DETECT_INSTALLED_GPUS(CUDA_ARCH_LIST)
    message(STATUS "Autodetected CUDA architecture(s): ${CUDA_ARCH_LIST}")
  endif()

  # Now process the list and look for names
  string(REGEX REPLACE "[ \t]+" ";" CUDA_ARCH_LIST "${CUDA_ARCH_LIST}")
  list(REMOVE_DUPLICATES CUDA_ARCH_LIST)
  foreach(arch_name ${CUDA_ARCH_LIST})
    set(arch_bin)
    set(add_ptx FALSE)
    # Check to see if we are compiling PTX
    if(arch_name MATCHES "(.*)\\+PTX$")
      set(add_ptx TRUE)
      set(arch_name ${CMAKE_MATCH_1})
    endif()
    if(arch_name MATCHES "(^[0-9]\\.[0-9](\\([0-9]\\.[0-9]\\))?)$")
      set(arch_bin ${CMAKE_MATCH_1})
      set(arch_ptx ${arch_bin})
    else()
      # Look for it in our list of known architectures
      if(${arch_name} STREQUAL "Fermi")
        set(arch_bin "2.0 2.1(2.0)")
      elseif(${arch_name} STREQUAL "Kepler+Tegra")
        set(arch_bin 3.2)
      elseif(${arch_name} STREQUAL "Kepler+Tesla")
        set(arch_bin 3.7)
      elseif(${arch_name} STREQUAL "Kepler")
        set(arch_bin 3.0 3.5)
        set(arch_ptx 3.5)
      elseif(${arch_name} STREQUAL "Maxwell+Tegra")
        set(arch_bin 5.3)
      elseif(${arch_name} STREQUAL "Maxwell")
        set(arch_bin 5.0 5.2)
        set(arch_ptx 5.2)
      elseif(${arch_name} STREQUAL "Pascal")
        set(arch_bin 6.0 6.1)
        set(arch_ptx 6.1)
      else()
        message(SEND_ERROR "Unknown CUDA Architecture Name ${arch_name} in CUDA_SELECT_NVCC_ARCH_FLAGS")
      endif()
    endif()
    if(NOT arch_bin)
      message(SEND_ERROR "arch_bin wasn't set for some reason")
    endif()
    list(APPEND cuda_arch_bin ${arch_bin})
    if(add_ptx)
      if (NOT arch_ptx)
        set(arch_ptx ${arch_bin})
      endif()
      list(APPEND cuda_arch_ptx ${arch_ptx})
    endif()
  endforeach()

  # remove dots and convert to lists
  string(REGEX REPLACE "\\." "" cuda_arch_bin "${cuda_arch_bin}")
  string(REGEX REPLACE "\\." "" cuda_arch_ptx "${cuda_arch_ptx}")
  string(REGEX MATCHALL "[0-9()]+" cuda_arch_bin "${cuda_arch_bin}")
  string(REGEX MATCHALL "[0-9]+" cuda_arch_ptx "${cuda_arch_ptx}")

  if(cuda_arch_bin)
    list(REMOVE_DUPLICATES cuda_arch_bin)
  endif()
  if(cuda_arch_ptx)
    list(REMOVE_DUPLICATES cuda_arch_ptx)
  endif()

  set(nvcc_flags "")
  set(nvcc_archs_readable "")

  # Tell NVCC to add binaries for the specified GPUs
  foreach(arch ${cuda_arch_bin})
    if(arch MATCHES "([0-9]+)\\(([0-9]+)\\)")
      # User explicitly specified ARCH for the concrete CODE
      list(APPEND nvcc_flags -gencode arch=compute_${CMAKE_MATCH_2},code=sm_${CMAKE_MATCH_1})
      list(APPEND nvcc_archs_readable sm_${CMAKE_MATCH_1})
    else()
      # User didn't explicitly specify ARCH for the concrete CODE, so we assume ARCH=CODE
      list(APPEND nvcc_flags -gencode arch=compute_${arch},code=sm_${arch})
      list(APPEND nvcc_archs_readable sm_${arch})
    endif()
  endforeach()

  # Tell NVCC to add PTX intermediate code for the specified architectures
  foreach(arch ${cuda_arch_ptx})
    list(APPEND nvcc_flags -gencode arch=compute_${arch},code=compute_${arch})
    list(APPEND nvcc_archs_readable compute_${arch})
  endforeach()

  string(REPLACE ";" " " nvcc_archs_readable "${nvcc_archs_readable}")
  set(${out_variable} ${nvcc_flags} PARENT_SCOPE)
  set(${out_variable}_readable ${nvcc_archs_readable} PARENT_SCOPE)
endfunction()