mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
EnsureDense/SparseToDense for CUDA
Summary: Make CUDA version of SparseToDense, register EnsureDense (which is trivial) on CUDA. Need to use atomics because indices can be duplicated. We can later add an option to inform if the indices are unique, and use faster path then. Reviewed By: jhcross Differential Revision: D5750893 fbshipit-source-id: 005d1675b127a571aac8474fca62d9633f0c7bff
This commit is contained in:
parent
b2bd9ef15a
commit
bb08f261f1
|
|
@ -4,7 +4,6 @@
|
|||
|
||||
namespace caffe2 {
|
||||
|
||||
namespace {
|
||||
REGISTER_CPU_OPERATOR(SparseToDense, SparseToDenseOp<CPUContext>);
|
||||
|
||||
OPERATOR_SCHEMA(SparseToDense)
|
||||
|
|
@ -47,6 +46,8 @@ output[j, ...] = 0 if j not in indices
|
|||
"len(mask)] + shape(default_value)` (if `lengths` is not provided the "
|
||||
"first dimension is omitted)");
|
||||
|
||||
|
||||
namespace {
|
||||
class GetSparseToDenseGradient : public GradientMakerBase {
|
||||
using GradientMakerBase::GradientMakerBase;
|
||||
vector<OperatorDef> GetGradientDefs() override {
|
||||
|
|
@ -56,5 +57,5 @@ class GetSparseToDenseGradient : public GradientMakerBase {
|
|||
};
|
||||
|
||||
REGISTER_GRADIENT(SparseToDense, GetSparseToDenseGradient);
|
||||
} // namespace
|
||||
}
|
||||
} // namespace caffe2
|
||||
|
|
|
|||
77
caffe2/operators/sparse_to_dense_op.cu
Normal file
77
caffe2/operators/sparse_to_dense_op.cu
Normal file
|
|
@ -0,0 +1,77 @@
|
|||
#include "sparse_to_dense_op.h"
|
||||
|
||||
#include "caffe2/core/common_gpu.h"
|
||||
#include "caffe2/core/context_gpu.h"
|
||||
|
||||
namespace caffe2 {
|
||||
|
||||
template <typename TInd, typename TData>
|
||||
__global__ void SparseToDenseKernel(
|
||||
size_t N, TIndex block_nitems, const TInd* indices, const TData* vals, TData* dst) {
|
||||
CUDA_1D_KERNEL_LOOP(i, N) {
|
||||
int idx = indices[i / block_nitems];
|
||||
int dst_idx = block_nitems * idx + i % block_nitems;
|
||||
atomicAdd(&dst[dst_idx], vals[i]);
|
||||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
bool SparseToDenseOp<CUDAContext>::RunOnDevice() {
|
||||
return DispatchHelper<TensorTypes<int32_t>>::call(
|
||||
this, Input(INDICES));
|
||||
}
|
||||
|
||||
template <>
|
||||
template <typename TInd>
|
||||
bool SparseToDenseOp<CUDAContext>::DoRunWithType() {
|
||||
return DispatchHelper<
|
||||
TensorTypes2<
|
||||
float,
|
||||
int32_t>,
|
||||
TInd>::call(this, Input(VALUES));
|
||||
}
|
||||
|
||||
template <>
|
||||
template <typename TInd, typename TData>
|
||||
bool SparseToDenseOp<CUDAContext>::DoRunWithType2() {
|
||||
auto& sparse_indices = Input(INDICES);
|
||||
CAFFE_ENFORCE_EQ(sparse_indices.ndim(), 1);
|
||||
auto& sparse_values = Input(VALUES);
|
||||
CAFFE_ENFORCE_GE(sparse_values.ndim(), 1);
|
||||
CAFFE_ENFORCE_EQ(sparse_indices.size(), sparse_values.dim(0));
|
||||
|
||||
const TInd* sparse_indices_vec = sparse_indices.template data<TInd>();
|
||||
const int32_t sparse_indices_len = sparse_indices.dim32(0);
|
||||
const int output_first_dim =
|
||||
GetOutputFirstDim(sparse_indices_vec, sparse_indices_len);
|
||||
|
||||
auto shape = sparse_values.dims();
|
||||
shape[0] = output_first_dim;
|
||||
auto* output = Output(0);
|
||||
output->Resize(shape);
|
||||
|
||||
TData* output_data = output->template mutable_data<TData>();
|
||||
math::Set<TData>(output->size(), TData(0), output_data, &context_);
|
||||
|
||||
const auto block_nitems = sparse_values.size_from_dim(1);
|
||||
const TData* sparse_values_vec = sparse_values.template data<TData>();
|
||||
|
||||
size_t N = block_nitems * sparse_indices_len;
|
||||
CAFFE_ENFORCE_EQ(output->size(), output_first_dim * block_nitems);
|
||||
SparseToDenseKernel<TInd, TData><<<
|
||||
CAFFE_GET_BLOCKS(N), CAFFE_CUDA_NUM_THREADS, 0,
|
||||
context_.cuda_stream()>>>(
|
||||
N,
|
||||
block_nitems,
|
||||
sparse_indices_vec,
|
||||
sparse_values_vec,
|
||||
output_data
|
||||
);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
REGISTER_CUDA_OPERATOR(SparseToDense, SparseToDenseOp<CUDAContext>);
|
||||
|
||||
} // namespace caffe2
|
||||
|
|
@ -40,9 +40,15 @@ class SparseToDenseOp final : public Operator<Context> {
|
|||
if (sparse_indices_len <= 0) {
|
||||
return 0;
|
||||
}
|
||||
return 1 +
|
||||
*std::max_element(
|
||||
sparse_indices_vec, sparse_indices_vec + sparse_indices_len);
|
||||
|
||||
// Awkward way to get the max element to make it work with both CUDA
|
||||
// and CPU.
|
||||
max_element_.Resize(1);
|
||||
TInd* max_element_ptr = max_element_.template mutable_data<TInd>();
|
||||
math::ReduceMax<TInd>(sparse_indices_len, sparse_indices_vec, max_element_ptr,
|
||||
&scratch_, &context_);
|
||||
max_element_host_.CopyFrom(max_element_);
|
||||
return 1 + max_element_host_.template data<TInd>()[0];
|
||||
}
|
||||
|
||||
template <typename TInd>
|
||||
|
|
@ -104,6 +110,9 @@ class SparseToDenseOp final : public Operator<Context> {
|
|||
|
||||
private:
|
||||
int output_first_dim_;
|
||||
Tensor<Context> scratch_;
|
||||
Tensor<CPUContext> max_element_host_;
|
||||
Tensor<Context> max_element_;
|
||||
|
||||
INPUT_TAGS(INDICES, VALUES, DATA_TO_INFER_DIM);
|
||||
};
|
||||
|
|
|
|||
|
|
@ -15,6 +15,8 @@
|
|||
namespace caffe2 {
|
||||
CAFFE_KNOWN_TYPE(const float*);
|
||||
|
||||
REGISTER_CUDA_OPERATOR(EnsureDense, EnsureDenseOp<CUDAContext>);
|
||||
|
||||
|
||||
__global__ void NanCheckKernel(int N, const float* X, bool* result) {
|
||||
bool has_nan = false;
|
||||
|
|
|
|||
|
|
@ -2128,11 +2128,19 @@ class TestOperators(hu.HypothesisTestCase):
|
|||
|
||||
@given(inp=_dtypes().flatmap(lambda dt: _tensor_and_indices(
|
||||
elements=st.floats(min_value=0, max_value=1), dtype=dt)),
|
||||
**hu.gcs_cpu_only)
|
||||
**hu.gcs)
|
||||
def test_sparse_to_dense(self, inp, gc, dc):
|
||||
first_dim, X, I = inp
|
||||
if X.dtype != np.dtype('float32') and gc.device_type == 1:
|
||||
# Cuda only support 32 bit float
|
||||
print("Bailout {}".format(X.dtype))
|
||||
return
|
||||
if gc.device_type == 1:
|
||||
# Cuda version only support int32
|
||||
I = I.astype(np.int32)
|
||||
|
||||
# values don't matter
|
||||
D = np.zeros((first_dim,) + X.shape[1:])
|
||||
D = np.zeros((first_dim,) + X.shape[1:]).astype(X.dtype)
|
||||
|
||||
op = core.CreateOperator("SparseToDense", ["I", "X", "D"], ["Y"])
|
||||
|
||||
|
|
@ -2143,8 +2151,6 @@ class TestOperators(hu.HypothesisTestCase):
|
|||
return [O]
|
||||
|
||||
self.assertReferenceChecks(gc, op, [I, X, D], sparse_to_dense)
|
||||
self.assertDeviceChecks(dc, op, [I, X, D], [0])
|
||||
|
||||
X = X.astype(np.float32)
|
||||
self.assertGradientChecks(gc, op, [I, X, D], 1, [0])
|
||||
|
||||
|
|
|
|||
|
|
@ -579,6 +579,9 @@ CAFFE2_SPECIALIZED_REDUCEMIN(float)
|
|||
*y = *std::max_element(x, x + N); \
|
||||
}
|
||||
CAFFE2_SPECIALIZED_REDUCEMAX(float)
|
||||
CAFFE2_SPECIALIZED_REDUCEMAX(int32_t)
|
||||
CAFFE2_SPECIALIZED_REDUCEMAX(int64_t)
|
||||
|
||||
#undef CAFFE2_SPECIALIZED_REDUCEMAX
|
||||
|
||||
#define CAFFE2_SPECIALIZED_ROWWISEMAX(T) \
|
||||
|
|
|
|||
|
|
@ -85,6 +85,7 @@ DELEGATE_SINCOS_CUDA_FUNCTION(double)
|
|||
}
|
||||
|
||||
DELEGATE_SIMPLE_CUDA_BINARY_INFIX_FUNCTION(float, Add, +);
|
||||
DELEGATE_SIMPLE_CUDA_BINARY_INFIX_FUNCTION(int32_t, Add, +);
|
||||
DELEGATE_SIMPLE_CUDA_BINARY_INFIX_FUNCTION(float, Sub, -);
|
||||
DELEGATE_SIMPLE_CUDA_BINARY_INFIX_FUNCTION(float, Mul, *);
|
||||
DELEGATE_SIMPLE_CUDA_BINARY_INFIX_FUNCTION(float, Div, /);
|
||||
|
|
@ -144,6 +145,9 @@ DELEGATE_SIMPLE_CUDA_BINARY_PREFIX_FUNCTION(float, ElemwiseMax, fmaxf);
|
|||
|
||||
DELEGATE_REDUCTION_FUNCTION(float, ReduceMin, Min)
|
||||
DELEGATE_REDUCTION_FUNCTION(float, ReduceMax, Max)
|
||||
DELEGATE_REDUCTION_FUNCTION(int32_t, ReduceMax, Max)
|
||||
DELEGATE_REDUCTION_FUNCTION(int64_t, ReduceMax, Max)
|
||||
|
||||
|
||||
#undef DELEGATE_REDUCTION_FUNCTION
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user