diff --git a/aten/src/ATen/Dispatch.h b/aten/src/ATen/Dispatch.h
index e0d8a9a525b..9f0c5116617 100644
--- a/aten/src/ATen/Dispatch.h
+++ b/aten/src/ATen/Dispatch.h
@@ -557,6 +557,25 @@ inline void deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX() {}
     }                                                                        \
   }()
 
+#define AT_DISPATCH_INDEX_TYPES(TYPE, NAME, ...)                             \
+  [&] {                                                                      \
+    const auto& the_index_type = TYPE;                                       \
+    /* don't use TYPE again in case it is an expensive or side-effect op */  \
+    at::ScalarType _it = ::detail::scalar_type(the_index_type);              \
+    switch (_it) {                                                           \
+      case at::ScalarType::Int: {                                            \
+        using index_t = int32_t;                                             \
+        return __VA_ARGS__();                                                \
+      }                                                                      \
+      case at::ScalarType::Long: {                                           \
+        using index_t = int64_t;                                             \
+        return __VA_ARGS__();                                                \
+      }                                                                      \
+      default:                                                               \
+        AT_ERROR(#NAME, " not implemented for '", toString(_it), "'");       \
+    }                                                                        \
+  }()
+
 // ----------------------------------------------------------------------------
 // DEPRECATED MACROS, DON'T USE THESE
 // ----------------------------------------------------------------------------
diff --git a/aten/src/ATen/native/Embedding.cpp b/aten/src/ATen/native/Embedding.cpp
index 6589a33ed2f..bf74e8b356c 100644
--- a/aten/src/ATen/native/Embedding.cpp
+++ b/aten/src/ATen/native/Embedding.cpp
@@ -15,7 +15,7 @@ Tensor embedding(const Tensor & weight, const Tensor & indices,
                  int64_t padding_idx, bool scale_grad_by_freq, bool sparse) {
   TORCH_CHECK(weight.dim() >= 1, "'weight' must be at least 1-D");
   auto indices_arg = TensorArg(indices, "indices", 1);
-  checkScalarType("embedding", indices_arg, kLong);
+  checkScalarTypes("embedding", indices_arg, {kLong, kInt});
 
   auto zerofill_padding = [&](Tensor& embedding) {
     if (padding_idx >= 0) {
@@ -57,7 +57,7 @@ Tensor embedding_sparse_backward(
     int64_t padding_idx, bool scale_grad_by_freq) {
 
   auto indices_arg = TensorArg(indices_, "indices", 2);
-  checkScalarType("embedding_backward", indices_arg, kLong);
+  checkScalarTypes("embedding_backward", indices_arg, {kLong, kInt});
 
   // TODO: implement scale_grad_by_freq
   if (scale_grad_by_freq) {
@@ -79,14 +79,14 @@ Tensor embedding_sparse_backward(
 
   // check if all our grad come from padding_idx
   if (grad.numel() == 0) {
-    return at::_sparse_coo_tensor_unsafe(at::empty({1, 0}, indices_.options()),
+    return at::_sparse_coo_tensor_unsafe(at::empty({1, 0}, indices_.options().dtype(kLong)),
                              at::empty({0, num_features}, dense_options),
                              weight_size);
   }
 
   auto index = indices.reshape({1, -1});
   auto values = grad.reshape({-1, num_features});
-  return at::_sparse_coo_tensor_unsafe(index, values, weight_size);
+  return at::_sparse_coo_tensor_unsafe(index.to(kLong), values, weight_size);
 }
 
 Tensor embedding_dense_backward_cpu(
@@ -94,50 +94,48 @@ Tensor embedding_dense_backward_cpu(
     int64_t padding_idx, bool scale_grad_by_freq) {
 
   auto indices_arg = TensorArg(indices, "indices", 2);
-  checkScalarType("embedding_backward", indices_arg, kLong);
+  checkScalarTypes("embedding_backward", indices_arg, {kLong, kInt});
 
-  auto indices_contig = indices.contiguous();
-  auto indices_data = indices_contig.data_ptr<int64_t>();
-  int64_t numel = indices.numel();
-
-  std::unique_ptr<int64_t[]> counts;
-  if (scale_grad_by_freq) {
-    counts.reset(new int64_t[num_weights]);
-    for (int i = 0; i < numel; i++) {
-      counts[indices_data[i]] = 0;
-    }
-    for (int i = 0; i < numel; i++) {
-      counts[indices_data[i]]++;
-    }
-  }
-
-  auto grad = grad_.contiguous().view({numel, grad_.size(-1)});
   auto grad_weight = at::zeros({num_weights, grad_.size(-1)}, grad_.options());
+  auto indices_contig = indices.contiguous();
+  int64_t numel = indices.numel();
+  auto grad =
grad_.contiguous().view({numel, grad_.size(-1)}); - auto parallel_section = [&](int64_t start, int64_t end) { - for (int64_t i = 0; i < numel; i++) { - if (indices_data[i] != padding_idx) { - int64_t k = indices_data[i]; - if (k >= start && k < end) { - double scale = 1.0; - if (scale_grad_by_freq) { - scale /= counts[k]; - } - grad_weight[k].add_(grad[i], scale); - } + AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_dense_backward_cpu", [&] () { + auto indices_data = indices_contig.data_ptr(); + + std::unique_ptr counts; + if (scale_grad_by_freq) { + counts.reset(new index_t[num_weights]); + for (int i = 0; i < numel; i++) { + counts[indices_data[i]] = 0; + } + for (int i = 0; i < numel; i++) { + counts[indices_data[i]]++; } } - }; - if (numel > 1000) { - // The strategy is to parallelize over sections of the vocabulary, so that - // thread 1 handles updates to gradWeight[0..nVocab/nThreads]. Every thread - // has to traverse the entire input, but the dominating factor is the axpy - // BLAS call. - at::parallel_for(0, num_weights, 0, parallel_section); - } else { - parallel_section(0, num_weights); - } + auto parallel_section = [&](index_t start, index_t end) { + for (int64_t i = 0; i < numel; i++) { + if (indices_data[i] != padding_idx) { + index_t k = indices_data[i]; + if (k >= start && k < end) { + double scale = 1.0; + if (scale_grad_by_freq) { + scale /= counts[k]; + } + grad_weight[k].add_(grad[i], scale); + } + } + } + }; + + if (numel > 1000) { + at::parallel_for(0, num_weights, 0, parallel_section); + } else { + parallel_section(0, num_weights); + } + }); return grad_weight; } @@ -147,28 +145,30 @@ Tensor & embedding_renorm_cpu_( auto self_arg = TensorArg(self, "self", 1); auto indices_arg = TensorArg(indices, "indices", 2); checkDim("embedding_renorm_", self_arg, 2); - checkScalarType("embedding_renorm_", indices_arg, kLong); + checkScalarTypes("embedding_renorm_", indices_arg, {kLong, kInt}); auto indices_contig = indices.contiguous(); - auto num_indices = indices.numel(); - auto data_ptr = indices_contig.data_ptr(); - auto sorted_indices = std::vector(data_ptr, data_ptr + num_indices); - std::sort(sorted_indices.begin(), sorted_indices.end(), std::less()); - // Note that we cannot use at::parallel_for here because we perform operations on - // Tensor inside the loop. See github.com/pytorch/pytorch/issues/28370 for more details. - for (auto i = 0; i < num_indices; i++) { - if (i > 0 && sorted_indices[i] == sorted_indices[i - 1]) { - continue; + AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_renorm_cpu_", [&]() { + auto data_ptr = indices_contig.data_ptr(); + auto sorted_indices = std::vector(data_ptr, data_ptr + num_indices); + std::sort(sorted_indices.begin(), sorted_indices.end()); + + // Note that we cannot use at::parallel_for here because we perform operations on + // Tensor inside the loop. See github.com/pytorch/pytorch/issues/28370 for more details. 
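// ---------------------------------------------------------------------------
// [Illustrative sketch, not part of this patch] The AT_DISPATCH_INDEX_TYPES
// macro added in Dispatch.h above is used like the existing AT_DISPATCH_*
// macros: the lambda body is compiled once per supported index dtype, with
// index_t bound to int32_t for kInt and int64_t for kLong. The helper below
// (sum_indices_cpu) is hypothetical and only demonstrates the usage pattern.
#include <ATen/ATen.h>
#include <ATen/Dispatch.h>

int64_t sum_indices_cpu(const at::Tensor& indices) {
  auto contig = indices.contiguous();
  int64_t total = 0;
  AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "sum_indices_cpu", [&]() {
    const index_t* data = contig.data_ptr<index_t>();  // index_t comes from the macro
    for (int64_t i = 0; i < contig.numel(); i++) {
      total += static_cast<int64_t>(data[i]);
    }
  });
  return total;
}
// ---------------------------------------------------------------------------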
+ for (auto i = 0; i < num_indices; i++) { + if (i > 0 && sorted_indices[i] == sorted_indices[i - 1]) { + continue; + } + auto row = self[sorted_indices[i]]; + auto norm = row.norm(norm_type).item(); + if (norm > max_norm) { + auto scale = max_norm / (norm + 1e-7); + row *= scale; + } } - auto row = self[sorted_indices[i]]; - auto norm = row.norm(norm_type).item(); - if (norm > max_norm) { - auto scale = max_norm / (norm + 1e-7); - row *= scale; - } - } + }); return self; } diff --git a/aten/src/ATen/native/EmbeddingBag.cpp b/aten/src/ATen/native/EmbeddingBag.cpp index a0b2a37ed6d..ef318285ed4 100644 --- a/aten/src/ATen/native/EmbeddingBag.cpp +++ b/aten/src/ATen/native/EmbeddingBag.cpp @@ -32,11 +32,11 @@ namespace native { template scalar_t dot_impl(int64_t n, scalar_t *x, int64_t incx, scalar_t *y, int64_t incy); -static void make_offset2bag(const Tensor &offsets, const Tensor &indices, Tensor& offset2bag) { +static void make_offset2bag(const Tensor &offsets, Tensor& offset2bag) { offset2bag.index_add_( 0, offsets, at::ones_like(offsets, LEGACY_CONTIGUOUS_MEMORY_FORMAT)); // offset2bag = [1 0 1 0 1] offset2bag[0] -= 1; // offset2bag = [0 0 1 0 1] - offset2bag = offset2bag.cumsum(0); // offset2bag = [0 0 1 1 2] + offset2bag = offset2bag.cumsum(0, offset2bag.scalar_type()); // offset2bag = [0 0 1 1 2] } namespace { @@ -52,18 +52,19 @@ bool isFastPathIndexSelectScale(const Tensor& src, const Tensor& scale, Tensor& // This function combines index_select (using select_indices as the index) and // index_add (using add_indices as the index), without creating an intermediary // tensor to hold the selected embeddings -template -void index_select_add(const Tensor &select_indices, +template +typename std::enable_if::value, void>::type +index_select_add(const Tensor &select_indices, const Tensor &add_indices, const Tensor &src, Tensor &output, const Tensor& /*offsets*/, bool /*include_last_offset*/) { AT_ASSERT(select_indices.numel() == add_indices.numel()); - auto* add_indices_data = add_indices.data_ptr(); - auto* select_indices_data = select_indices.data_ptr(); - auto* src_data = src.data_ptr(); - auto* output_data = output.data_ptr(); + auto* add_indices_data = add_indices.data_ptr(); + auto* select_indices_data = select_indices.data_ptr(); + auto* src_data = src.data_ptr(); + auto* output_data = output.data_ptr(); auto numel = add_indices.numel(); int64_t ddim = src.size(1); auto src_stride0 = src.stride(0); @@ -72,29 +73,30 @@ void index_select_add(const Tensor &select_indices, auto output_stride1 = output.stride(1); for (int64_t i = 0; i < numel; i++) { - THBlas_axpy(ddim, 1, + THBlas_axpy(ddim, 1, src_data + src_stride0 * select_indices_data[i], src_stride1, output_data + output_stride0 * add_indices_data[i], output_stride1); } } -template<> -void index_select_add(const Tensor &select_indices, +template +typename std::enable_if::value, void>::type +index_select_add(const Tensor &select_indices, const Tensor &add_indices, const Tensor &src, Tensor &output, const Tensor& offsets, bool include_last_offset) { int64_t ddim = src.size(1); - auto* select_indices_data = select_indices.data_ptr(); + auto* select_indices_data = select_indices.data_ptr(); auto* output_data = output.data_ptr(); if (isFastPathIndexSelect(src, output)) { auto src_contig = src.contiguous(); auto* src_data = src_contig.data_ptr(); int64_t output_size = offsets.numel() - 1; - auto* offsets_data = offsets.data_ptr(); - std::vector offsets_include_last; + auto* offsets_data = offsets.data_ptr(); + std::vector 
offsets_include_last; if (include_last_offset) { output_size = offsets.numel() - 1; @@ -103,15 +105,15 @@ void index_select_add(const Tensor &select_indices, offsets_include_last.resize(offsets.numel() + 1); std::memcpy( offsets_include_last.data(), - offsets.data_ptr(), - sizeof(int64_t) * offsets.numel()); + offsets.data_ptr(), + sizeof(index_t) * offsets.numel()); offsets_include_last[offsets.numel()] = select_indices.numel(); offsets_data = offsets_include_last.data(); } #ifdef USE_FBGEMM - auto kernel_fp32_i64 = - fbgemm::GenerateEmbeddingSpMDM( + auto kernel_fp32_index_t = + fbgemm::GenerateEmbeddingSpMDM( /* block_size */ddim, /* has_weight */false, /* normalize_by_lengths */false, @@ -121,9 +123,9 @@ void index_select_add(const Tensor &select_indices, ); #endif at::parallel_for( - 0, output_size, 1, [&](int64_t start_idx, int64_t end_idx) { + 0, output_size, 1, [&](index_t start_idx, index_t end_idx) { #ifdef USE_FBGEMM - kernel_fp32_i64( + kernel_fp32_index_t( /* output_size */end_idx - start_idx, /* index_size */offsets_data[end_idx] - offsets_data[start_idx], /* data_size */src.size(0), @@ -150,7 +152,7 @@ void index_select_add(const Tensor &select_indices, } else { AT_ASSERT(select_indices.numel() == add_indices.numel()); auto* src_data = src.data_ptr(); - auto* add_indices_data = add_indices.data_ptr(); + auto* add_indices_data = add_indices.data_ptr(); auto src_stride0 = src.stride(0); auto src_stride1 = src.stride(1); auto output_stride0 = output.stride(0); @@ -172,8 +174,9 @@ void index_select_add(const Tensor &select_indices, // index_select (using select_indices as the index) // mul (scaling by per_sample_weights) // index_add (using add_indices as the index) -template -static void index_select_scale_add(const Tensor &select_indices, +template +static typename std::enable_if::value, void>::type +index_select_scale_add(const Tensor &select_indices, const Tensor &add_indices, const Tensor &scale, const Tensor &src, @@ -181,10 +184,10 @@ static void index_select_scale_add(const Tensor &select_indices, const Tensor& /*offsets*/, bool /*include_last_offset*/) { AT_ASSERT(select_indices.numel() == add_indices.numel()); - auto* add_indices_data = add_indices.data_ptr(); - auto* select_indices_data = select_indices.data_ptr(); - auto* src_data = src.data_ptr(); - auto* output_data = output.data_ptr(); + auto* add_indices_data = add_indices.data_ptr(); + auto* select_indices_data = select_indices.data_ptr(); + auto* src_data = src.data_ptr(); + auto* output_data = output.data_ptr(); auto numel = add_indices.numel(); int64_t ddim = src.size(1); auto src_stride0 = src.stride(0); @@ -192,7 +195,7 @@ static void index_select_scale_add(const Tensor &select_indices, auto output_stride0 = output.stride(0); auto output_stride1 = output.stride(1); - auto* scale_data = scale.data_ptr(); + auto* scale_data = scale.data_ptr(); auto scale_stride = scale.stride(0); for (int64_t i = 0; i < numel; i++) { @@ -205,8 +208,9 @@ static void index_select_scale_add(const Tensor &select_indices, } } -template<> -void index_select_scale_add(const Tensor &select_indices, +template +typename std::enable_if::value, void>::type +index_select_scale_add(const Tensor &select_indices, const Tensor &add_indices, const Tensor &scale, const Tensor &src, @@ -215,15 +219,15 @@ void index_select_scale_add(const Tensor &select_indices, bool include_last_offset) { int64_t ddim = src.size(1); auto* scale_data = scale.data_ptr(); - auto* select_indices_data = select_indices.data_ptr(); + auto* select_indices_data = 
select_indices.data_ptr(); auto* output_data = output.data_ptr(); if (isFastPathIndexSelectScale(src, scale, output)) { auto src_contig = src.contiguous(); auto* src_data = src_contig.data_ptr(); int64_t output_size = offsets.numel() - 1; - auto* offsets_data = offsets.data_ptr(); - std::vector offsets_include_last; + auto* offsets_data = offsets.data_ptr(); + std::vector offsets_include_last; if (include_last_offset) { output_size = offsets.numel() - 1; @@ -232,15 +236,15 @@ void index_select_scale_add(const Tensor &select_indices, offsets_include_last.resize(offsets.numel() + 1); std::memcpy( offsets_include_last.data(), - offsets.data_ptr(), - sizeof(int64_t) * offsets.numel()); + offsets.data_ptr(), + sizeof(index_t) * offsets.numel()); offsets_include_last[offsets.numel()] = select_indices.numel(); offsets_data = offsets_include_last.data(); } #ifdef USE_FBGEMM - auto kernel_fp32_i64 = - fbgemm::GenerateEmbeddingSpMDM( + auto kernel_fp32_index_t = + fbgemm::GenerateEmbeddingSpMDM( /* block_size */ddim, /* has_weight */true, /* normalize_by_lengths */false, @@ -250,9 +254,9 @@ void index_select_scale_add(const Tensor &select_indices, ); #endif at::parallel_for( - 0, output_size, 1, [&](int64_t start_idx, int64_t end_idx) { + 0, output_size, 1, [&](index_t start_idx, index_t end_idx) { #ifdef USE_FBGEMM - kernel_fp32_i64( + kernel_fp32_index_t( /* output_size */end_idx - start_idx, /* index_size */offsets_data[end_idx] - offsets_data[start_idx], /* data_size */src.size(0), @@ -279,7 +283,7 @@ void index_select_scale_add(const Tensor &select_indices, } else { AT_ASSERT(select_indices.numel() == add_indices.numel()); auto* src_data = src.data_ptr(); - auto* add_indices_data = add_indices.data_ptr(); + auto* add_indices_data = add_indices.data_ptr(); auto src_stride0 = src.stride(0); auto src_stride1 = src.stride(1); auto output_stride0 = output.stride(0); @@ -308,7 +312,7 @@ static at::Tensor make_bag_size( const bool requires_grad) { at::Tensor bag_size; if (mode == MODE_MEAN || mode == MODE_MAX) { - bag_size = at::zeros(offsets.sizes(), indices.options()); + bag_size = at::zeros(offsets.sizes(), offsets.options()); // Compute this for MODE_MEAN and MODE_MAX (latter needed for backwards) if (offsets.size(0) != 1) { bag_size.slice(0, 0, bag_size.size(0) - 1, 1) = @@ -318,7 +322,7 @@ static at::Tensor make_bag_size( bag_size[-1] = indices.size(0) - offsets[-1]; } else if (requires_grad) { // in MODE_SUM, only allocate bag_size if we need gradients - bag_size = at::empty(offsets.sizes(), indices.options()); + bag_size = at::empty(offsets.sizes(), offsets.options()); } return bag_size; } @@ -384,35 +388,36 @@ std::tuple embedding_bag_cpu_max( } auto max_indices = at::zeros({numBags, featureSize}, indices.options()); + AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_bag_cpu_max", [&] { + auto* indices_data = indices.data_ptr(); + auto* offset2bag_data = offset2bag.data_ptr(); - auto* indices_data = indices.data_ptr(); - auto* offset2bag_data = offset2bag.data_ptr(); + auto* max_indices_data = max_indices.data_ptr(); + auto max_indices_stride = max_indices.stride(0); - auto* max_indices_data = max_indices.data_ptr(); - auto max_indices_stride = max_indices.stride(0); + auto* weight_data = weight.data_ptr(); + auto* output_data = output.data_ptr(); + auto weight_stride0 = weight.stride(0); + auto weight_stride1 = weight.stride(1); + auto output_stride = output.stride(0); - auto* weight_data = weight.data_ptr(); - auto* output_data = output.data_ptr(); - auto weight_stride0 = 
weight.stride(0); - auto weight_stride1 = weight.stride(1); - auto output_stride = output.stride(0); + for (int i = 0; i < numIndices; ++i) { + auto bag = offset2bag_data[i]; + auto word_idx = indices_data[i]; - for (int i = 0; i < numIndices; i++) { - auto bag = offset2bag_data[i]; - auto word_idx = indices_data[i]; + for (int dim = 0; dim < featureSize; dim++) { + auto& current_item = output_data[output_stride * bag + dim]; + auto weight_item = + weight_data[weight_stride0 * word_idx + dim * weight_stride1]; + bool is_first_for_bag = (i == 0) || offset2bag_data[i - 1] != bag; - for (int dim = 0; dim < featureSize; dim++) { - auto& current_item = output_data[output_stride * bag + dim]; - auto weight_item = - weight_data[weight_stride0 * word_idx + dim * weight_stride1]; - bool is_first_for_bag = (i == 0) || offset2bag_data[i - 1] != bag; - - if (is_first_for_bag || weight_item > current_item) { - current_item = weight_item; - max_indices_data[max_indices_stride * bag + dim] = word_idx; + if (is_first_for_bag || weight_item > current_item) { + current_item = weight_item; + max_indices_data[max_indices_stride * bag + dim] = word_idx; + } } } - } + }); return std::tuple( output, offset2bag, bag_size, max_indices); @@ -429,19 +434,23 @@ std::tuple _embedding_bag_cpu_impl( bool include_last_offset, bool requires_grad) { auto indices_arg = TensorArg(indices, "indices", 1); - checkScalarType("embedding_bag", indices_arg, kLong); + checkScalarTypes("embedding_bag", indices_arg, {kLong, kInt}); auto offsets_arg = TensorArg(offsets, "offsets", 1); - checkScalarType("embedding_bag", offsets_arg, kLong); + checkScalarTypes("embedding_bag", offsets_arg, {kLong, kInt}); + checkSameType("embedding_bag", indices_arg, offsets_arg); auto weight_arg = TensorArg(weight, "weight", 1); checkScalarTypes("embedding_bag", weight_arg, {kFloat, kDouble}); - int64_t offset_0 = offsets.data_ptr()[0]; - int64_t offset_n = offsets.data_ptr()[offsets.size(0)-1]; - TORCH_CHECK(offset_0 == 0, "offsets[0] has to be 0, i.e., the first sequence " - "in the mini-batch has to start from position 0. " - "However, got ", offsets[0]); - TORCH_CHECK(offset_n <= indices.size(0), "offsets[-1] can not " - "be greater than input's length ", indices.size(0), " but got offsets[-1] of ", - offset_n); + + AT_DISPATCH_INDEX_TYPES(offsets.scalar_type(), "_embedding_bag_cpu_impl", [&]() { + index_t offset_0 = offsets.data_ptr()[0]; + index_t offset_n = offsets.data_ptr()[offsets.size(0)-1]; + TORCH_CHECK(offset_0 == 0, "offsets[0] has to be 0, i.e., the first sequence " + "in the mini-batch has to start from position 0. " + "However, got ", offsets[0]); + TORCH_CHECK(offset_n <= indices.size(0), "offsets[-1] can not " + "be greater than input's length ", indices.size(0), " but got offsets[-1] of ", + offset_n); + }); if (per_sample_weights.defined()) { TORCH_CHECK(mode == MODE_SUM, @@ -494,9 +503,9 @@ std::tuple _embedding_bag_cpu_impl( // throw out of bounds error. So to keep it simple we just add one more // entry to the end then get rid of it after make_offset2bag. 
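// ---------------------------------------------------------------------------
// [Illustrative sketch, not part of this patch] What the simplified
// make_offset2bag above computes, now carried out in the offsets' own dtype
// (kInt or kLong) rather than forcing kLong; kInt works because index_add_
// also gains int32 support later in this patch. For offsets = [0, 2, 4] and
// five indices, the intermediate values match the inline comments in the
// diff; the helper name below is hypothetical.
#include <ATen/ATen.h>

at::Tensor offsets_to_offset2bag(const at::Tensor& offsets, int64_t num_indices) {
  auto offset2bag = at::zeros({num_indices}, offsets.options());  // [0 0 0 0 0]
  offset2bag.index_add_(0, offsets, at::ones_like(offsets));      // [1 0 1 0 1]
  offset2bag[0] -= 1;                                             // [0 0 1 0 1]
  // cumsum in the index dtype, mirroring the change to make_offset2bag
  return offset2bag.cumsum(0, offset2bag.scalar_type());          // [0 0 1 1 2]
}
// ---------------------------------------------------------------------------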
offset2bag = at::zeros( - {indices.sizes()[0] + 1}, indices.options()); // offset2bag = [0 0 0 0 0] + {indices.sizes()[0] + 1}, offsets.options()); // offset2bag = [0 0 0 0 0] - make_offset2bag(offsets, indices, offset2bag); + make_offset2bag(offsets, offset2bag); offset2bag.resize_({indices.sizes()[0]}); @@ -505,14 +514,20 @@ std::tuple _embedding_bag_cpu_impl( } if (mode == MODE_MEAN || mode == MODE_SUM) { - AT_DISPATCH_FLOATING_TYPES(weight.scalar_type(), "embedding_bag_cpu", [&]() { - if (per_sample_weights.defined()) { - AT_ASSERT(mode == MODE_SUM); - index_select_scale_add( - indices, offset2bag, per_sample_weights, weight, output, offsets, include_last_offset); - } else { - index_select_add(indices, offset2bag, weight, output, offsets, include_last_offset); - } + // explicitly capture all required variables to work around windows build + // TODO: fix this when windows can correctly capture variables in nested lambda + AT_DISPATCH_FLOATING_TYPES(weight.scalar_type(), "embedding_bag_cpu", + [&indices, &offset2bag, &per_sample_weights, &weight, &output, &offsets, &include_last_offset, &mode]() { + AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_bag_cpu", + [&indices, &offset2bag, &per_sample_weights, &weight, &output, &offsets, &include_last_offset, &mode]() { + if (per_sample_weights.defined()) { + AT_ASSERT(mode == MODE_SUM); + index_select_scale_add( + indices, offset2bag, per_sample_weights, weight, output, offsets, include_last_offset); + } else { + index_select_add(indices, offset2bag, weight, output, offsets, include_last_offset); + } + }); }); auto ret = apply_bag_size(offsets, indices, mode, output, bag_size); return std::tuple(ret, offset2bag, bag_size, bag_size); @@ -598,23 +613,24 @@ Tensor _embedding_bag_backward(const Tensor &grad, const Tensor &indices, bool sparse, const Tensor& per_sample_weights) { auto indices_arg = TensorArg(indices, "indices", 1); - checkScalarType("embedding_bag", indices_arg, kLong); + checkScalarTypes("embedding_bag", indices_arg, {kLong, kInt}); checkContiguous("embedding_bag", indices_arg); auto offsets_arg = TensorArg(offsets, "offsets", 1); - checkScalarType("embedding_bag", offsets_arg, kLong); + checkScalarTypes("embedding_bag", offsets_arg, {kLong, kInt}); + checkSameType("embedding_bag", indices_arg, offsets_arg); checkContiguous("embedding_bag", offsets_arg); Tensor offset2bag_; if (indices.numel() != 0 && offset2bag.numel() == 0) { offset2bag_ = at::zeros( - {indices.sizes()[0] + 1}, indices.options()); // offset2bag = [0 0 0 0 0] + {indices.sizes()[0] + 1}, offsets.options()); // offset2bag = [0 0 0 0 0] - make_offset2bag(offsets, indices, offset2bag_); + make_offset2bag(offsets, offset2bag_); offset2bag_.resize_({indices.sizes()[0]}); } else { auto offset2bag_arg = TensorArg(offset2bag, "offset2bag", 1); - checkScalarType("embedding_bag", offset2bag_arg, kLong); + checkScalarTypes("embedding_bag", offset2bag_arg, {kLong, kInt}); checkContiguous("embedding_bag", offset2bag_arg); offset2bag_ = offset2bag; } @@ -648,11 +664,12 @@ static Tensor _embedding_bag_dense_backward_cpu_max( return index_grad_weight; } -static std::vector compute_counts( +template +static std::vector compute_counts( int64_t num_weights, - int64_t* indices_data, + index_t* indices_data, int64_t indices_length) { - std::vector counts(num_weights, 0); + std::vector counts(num_weights, 0); for (int i = 0; i < indices_length; i++) { counts[indices_data[i]]++; } @@ -668,12 +685,13 @@ static std::vector compute_counts( // counts_uniq: [3, 4, 6, 7] // // The 
unique indices can be found at index 0, 3, 4, 6. -static std::vector compute_counts_uniq( +template +static std::vector compute_counts_uniq( int64_t num_weights, - int64_t* indices_data, + index_t* indices_data, int64_t indices_length, - const std::vector& counts) { - std::vector counts_uniq; + const std::vector& counts) { + std::vector counts_uniq; counts_uniq.reserve(num_weights); int64_t o = 0; for (int64_t i = 0; i < indices_length; i += counts[indices_data[i]]) { @@ -714,54 +732,66 @@ void _embedding_bag_dense_backward_cpu_sum_mean( per_sample_weights_stride = per_sample_weights->stride(0); } - auto* indices_data = indices.data_ptr(); - auto* offsets_data = offsets_.data_ptr(); - auto* offset2bag_data = offset2bag.data_ptr(); int64_t numel = indices.numel(); - auto counts = compute_counts(num_weights, indices_data, numel); - auto next_unique_index_idx = - compute_counts_uniq(num_weights, indices_data, numel, counts); + // explicitly capture all required variables to work around windows build + // TODO: fix this when windows can correctly capture variables in nested lambda + AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "_embedding_bag_dense_backward_cpu_sum_mean", + [&indices, &offsets_, &offset2bag, &num_weights, &numel, &per_sample_weights, + &per_sample_weights_data, &per_sample_weights_stride, &mode, &scale_grad_by_freq, + &grad, &index_grad_weight] { + auto* indices_data = indices.data_ptr(); + auto* offsets_data = offsets_.data_ptr(); + auto* offset2bag_data = offset2bag.data_ptr(); - auto loop = [&](int64_t start, int64_t end) { - for (int64_t i = start; i < end; i++) { - int64_t start = i == 0 ? 0 : next_unique_index_idx[i - 1]; - int64_t index = indices_data[start]; - for (int64_t j = start; j < next_unique_index_idx[i]; j++) { - int64_t source = offset2bag_data[j]; - double scale = 1.0; - if (per_sample_weights) { - AT_ASSERT(mode == MODE_SUM); - scale = per_sample_weights_data[*per_sample_weights_stride * j]; - } - if (scale_grad_by_freq) { - scale /= counts[indices_data[i]]; - } - if (mode == 1) { // MODE_MEAN - if (offsets_.size(0) == 1) { - auto bag_size = indices.size(0); - scale /= bag_size; - } else { - if (source == offsets_.size(0) - 1) { - scale /= indices.size(0) - offsets_data[offsets_.size(0) - 1]; + auto counts = compute_counts(num_weights, indices_data, numel); + auto next_unique_index_idx = + compute_counts_uniq(num_weights, indices_data, numel, counts); + + auto loop = + [&next_unique_index_idx, &indices_data, &offset2bag_data, &per_sample_weights, + &mode, &per_sample_weights_data, &per_sample_weights_stride, &scale_grad_by_freq, + &counts, &offsets_, &indices, &offsets_data, &grad, &index_grad_weight](index_t start, index_t end) { + for (index_t i = start; i < end; i++) { + index_t start = i == 0 ? 
0 : next_unique_index_idx[i - 1]; + index_t index = indices_data[start]; + for (index_t j = start; j < next_unique_index_idx[i]; j++) { + index_t source = offset2bag_data[j]; + double scale = 1.0; + if (per_sample_weights) { + AT_ASSERT(mode == MODE_SUM); + scale = per_sample_weights_data[*per_sample_weights_stride * j]; + } + if (scale_grad_by_freq) { + scale /= counts[indices_data[i]]; + } + if (mode == 1) { // MODE_MEAN + if (offsets_.size(0) == 1) { + auto bag_size = indices.size(0); + scale /= bag_size; } else { - scale /= offsets_data[source + 1] - offsets_data[source]; + if (source == offsets_.size(0) - 1) { + scale /= indices.size(0) - offsets_data[offsets_.size(0) - 1]; + } else { + scale /= offsets_data[source + 1] - offsets_data[source]; + } } } + int64_t ddim = grad.size(1); + auto igwd = index_grad_weight.data_ptr(); + auto gd = grad.data_ptr(); + THBlas_axpy(ddim, (scalar_t)scale, gd + ddim * source, 1, + igwd + ddim * index, 1); } - int64_t ddim = grad.size(1); - auto igwd = index_grad_weight.data_ptr(); - auto gd = grad.data_ptr(); - THBlas_axpy(ddim, (scalar_t)scale, gd + ddim * source, 1, - igwd + ddim * index, 1); } + }; + + if (numel > 1000) { + at::parallel_for(0, (int64_t)next_unique_index_idx.size(), 0, loop); + } else { + loop(0, (int64_t)next_unique_index_idx.size()); } - }; - if (numel > 1000) { - at::parallel_for(0, (int64_t)next_unique_index_idx.size(), 0, loop); - } else { - loop(0, (int64_t)next_unique_index_idx.size()); - } + }); } Tensor _embedding_bag_dense_backward_cpu(const Tensor &grad_, const Tensor &indices_, @@ -820,20 +850,20 @@ Tensor _embedding_bag_per_sample_weights_backward_cpu_template( auto output = at::zeros({num_samples}, grad.options()); auto indices_arg = TensorArg(indices, "indices", 1); - checkScalarType("embedding_bag", indices_arg, kLong); + checkScalarTypes("embedding_bag", indices_arg, {kLong, kInt}); checkContiguous("embedding_bag", indices_arg); Tensor offset2bag_; if (indices.numel() != 0 && offset2bag.numel() == 0) { offset2bag_ = at::zeros( - {indices.sizes()[0] + 1}, indices.options()); // offset2bag = [0 0 0 0 0] + {indices.sizes()[0] + 1}, offset2bag.options()); // offset2bag = [0 0 0 0 0] - make_offset2bag(offsets, indices, offset2bag_); + make_offset2bag(offsets, offset2bag_); offset2bag_.resize_({indices.sizes()[0]}); } else { auto offset2bag_arg = TensorArg(offset2bag, "offset2bag", 1); - checkScalarType("embedding_bag", offset2bag_arg, kLong); + checkScalarTypes("embedding_bag", offset2bag_arg, {kLong, kInt}); checkContiguous("embedding_bag", offset2bag_arg); offset2bag_ = offset2bag; } @@ -846,23 +876,31 @@ Tensor _embedding_bag_per_sample_weights_backward_cpu_template( auto weight_stride0 = weight.stride(0); auto weight_stride1 = weight.stride(1); - auto* indices_data = indices.data_ptr(); + // explicitly capture all required variables to work around windows build + // TODO: fix this when windows can correctly capture variables in nested lambda + AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "_embedding_bag_per_sample_weights_backward_cpu_template", + [&indices, &output, &offset2bag_, &num_samples, &embedding_features, + &grad_data, &grad_stride0, &grad_stride1, &weight_data, &weight_stride0, &weight_stride1] () { + auto* indices_data = indices.data_ptr(); - // The following are contiguous - auto* output_data = output.data_ptr(); - auto* offset2bag_data = offset2bag_.data_ptr(); + // The following are contiguous + auto* output_data = output.data_ptr(); + auto* offset2bag_data = offset2bag_.data_ptr(); - // XXX: 64 was 
arbitrarily chosen. There is probably a sweet spot for this number. - parallel_for(0, num_samples, 64, [&](int64_t begin, int64_t end) { - for (int64_t sample_idx = begin; sample_idx < end; sample_idx++) { - auto bag_idx = offset2bag_data[sample_idx]; - auto embedding_idx = indices_data[sample_idx]; + // XXX: 64 was arbitrarily chosen. There is probably a sweet spot for this number. + parallel_for(0, num_samples, 64, + [&embedding_features, &grad_data, &grad_stride0, &grad_stride1, &weight_data, &weight_stride0, + &weight_stride1, &offset2bag_data, &indices_data, &output_data](index_t begin, index_t end) { + for (index_t sample_idx = begin; sample_idx < end; sample_idx++) { + auto bag_idx = offset2bag_data[sample_idx]; + auto embedding_idx = indices_data[sample_idx]; - output_data[sample_idx] = dot_impl( - embedding_features, - grad_data + grad_stride0 * bag_idx, grad_stride1, - weight_data + weight_stride0 * embedding_idx, weight_stride1); - } + output_data[sample_idx] = dot_impl( + embedding_features, + grad_data + grad_stride0 * bag_idx, grad_stride1, + weight_data + weight_stride0 * embedding_idx, weight_stride1); + } + }); }); return output; } diff --git a/aten/src/ATen/native/TensorAdvancedIndexing.cpp b/aten/src/ATen/native/TensorAdvancedIndexing.cpp index 39c2caea391..9abaedd9ff1 100644 --- a/aten/src/ATen/native/TensorAdvancedIndexing.cpp +++ b/aten/src/ATen/native/TensorAdvancedIndexing.cpp @@ -381,7 +381,8 @@ Tensor& index_add_cpu_(Tensor & self, int64_t dim, const Tensor & index, const T auto numel = index.numel(); TORCH_CHECK_INDEX(index.dim() <= 1, "index_add_(): Index is supposed to be a vector"); - TORCH_CHECK(index.scalar_type() == ScalarType::Long, "index_add_(): Expected dtype int64 for index"); + TORCH_CHECK(index.scalar_type() == ScalarType::Long || index.scalar_type() == ScalarType::Int, + "index_add_(): Expected dtype int32/int64 for index"); TORCH_CHECK(self.scalar_type() == source.scalar_type(), "index_add_(): self and source must have the same scalar type"); TORCH_CHECK(dim == 0 || dim < source.dim(), @@ -394,7 +395,6 @@ Tensor& index_add_cpu_(Tensor & self, int64_t dim, const Tensor & index, const T at::assert_no_partial_overlap(self, source); auto index_contig = index.contiguous(); - auto index_data = index_contig.data_ptr(); if (self.dim() > 1) { // Equivalent to: @@ -414,32 +414,41 @@ Tensor& index_add_cpu_(Tensor & self, int64_t dim, const Tensor & index, const T auto self_dim_size = self.size(dim); auto iter = TensorIterator::binary_op(selfSlice, selfSlice, sourceSlice); - for (auto i = 0; i < numel; i++) { - auto self_i = index_data[i]; - TORCH_CHECK_INDEX((self_i >= 0) && (self_i < self_dim_size), "index out of range in self"); - auto self_data = static_cast(selfSlice.data_ptr()) + self_i * self_stride_bytes; - auto source_data = static_cast(sourceSlice.data_ptr()) + i * source_stride_bytes; - iter.unsafe_replace_operand(0, self_data); - iter.unsafe_replace_operand(1, self_data); - iter.unsafe_replace_operand(2, source_data); - add_stub(iter.device_type(), iter, 1); - } + AT_DISPATCH_INDEX_TYPES(index.scalar_type(), "index_add_cpu_", [&] () { + auto index_data = index_contig.data_ptr(); + for (auto i = 0; i < numel; i++) { + auto self_i = index_data[i]; + TORCH_CHECK_INDEX((self_i >= 0) && (self_i < self_dim_size), "index out of range in self"); + auto self_data = static_cast(selfSlice.data_ptr()) + self_i * self_stride_bytes; + auto source_data = static_cast(sourceSlice.data_ptr()) + i * source_stride_bytes; + iter.unsafe_replace_operand(0, self_data); + 
iter.unsafe_replace_operand(1, self_data); + iter.unsafe_replace_operand(2, source_data); + add_stub(iter.device_type(), iter, 1); + } + }); } else { TORCH_CHECK(source.dim() <= 1, "source.dim() (", source.dim(), ") must one or zero for given self.dim() (", self.dim(), ")"); - AT_DISPATCH_ALL_TYPES(self.scalar_type(), "index_add_", [&] { + // explicitly capture all required variables to work around windows build + // TODO: fix this when windows can correctly capture variables in nested lambda + AT_DISPATCH_ALL_TYPES(self.scalar_type(), "index_add_", [&self, &source, &dim, &index_contig, &numel] { auto self_stride = self.dim() == 0 ? 1 : self.stride(dim); auto source_stride = source.dim() == 0 ? 1 : source.stride(dim); // TODO: Maybe TensorAccessor can beused here? auto* self_ptr = self.data_ptr(); auto* source_ptr = source.data_ptr(); - for (auto i = 0; i < numel; i++) { - auto self_i = index_data[i]; - TORCH_CHECK_INDEX((self_i >= 0) && (self_i < self.numel()), "index out of range in self"); - scalar_t *self_ip = self_ptr + self_i * self_stride; - *self_ip += *(source_ptr + i * source_stride); - } + AT_DISPATCH_INDEX_TYPES(index_contig.scalar_type(), "index_add_cpu_", + [&index_contig, &numel, &self, &self_ptr, &self_stride, &source_ptr, &source_stride] { + auto index_data = index_contig.data_ptr(); + for (auto i = 0; i < numel; i++) { + auto self_i = index_data[i]; + TORCH_CHECK_INDEX((self_i >= 0) && (self_i < self.numel()), "index out of range in self"); + scalar_t *self_ip = self_ptr + self_i * self_stride; + *self_ip += *(source_ptr + i * source_stride); + } + }); }); } return self; @@ -454,7 +463,7 @@ Tensor & index_select_out_cpu_(Tensor & result, const Tensor & self, int64_t dim auto numel = index.numel(); TORCH_CHECK_INDEX(index.dim() <= 1, "index_select(): Index is supposed to be a vector"); - TORCH_CHECK(index.scalar_type() == ScalarType::Long, "index_select(): Expected dtype int64 for index"); + TORCH_CHECK(index.scalar_type() == ScalarType::Long || index.scalar_type() == ScalarType::Int, "index_select(): Expected dtype int32 or int64 for index"); TORCH_CHECK(self.scalar_type() == result.scalar_type(), "index_select(): self and result must have the same scalar type"); TORCH_CHECK(dim == 0 || dim < self.dim(), @@ -468,7 +477,6 @@ Tensor & index_select_out_cpu_(Tensor & result, const Tensor & self, int64_t dim result.resize_(result_size); auto index_contig = index.contiguous(); - auto index_data = index_contig.data_ptr(); if (self.dim() > 1) { if (numel == 0 || self.numel() == 0) { @@ -492,17 +500,26 @@ Tensor & index_select_out_cpu_(Tensor & result, const Tensor & self, int64_t dim .build(); auto grain_size = at::internal::GRAIN_SIZE; - auto outer_loop = [&](int64_t start, int64_t end) { + auto outer_loop = + // explicitly capture all required variables to work around windows build + // TODO: fix this when windows can correctly capture variables in nested lambda + [&index_contig, &iter, &self_dim_size, &selfSlice_data, &self_stride_bytes, &resultSlice_data, + &result_stride_bytes](int64_t start, int64_t end) { auto sub_iter = TensorIterator(iter); - for (int64_t i = start; i < end; i++) { - auto self_i = index_data[i]; - TORCH_CHECK_INDEX((self_i >= 0) && (self_i < self_dim_size), "index out of range in self"); - auto self_data = static_cast(selfSlice_data) + self_i * self_stride_bytes; - auto result_data = static_cast(resultSlice_data) + i * result_stride_bytes; - sub_iter.unsafe_replace_operand(0, result_data); - sub_iter.unsafe_replace_operand(1, self_data); - 
copy_stub(sub_iter.device_type(), sub_iter, false); - } + AT_DISPATCH_INDEX_TYPES(index_contig.scalar_type(), "index_select_out_cpu_", + [&index_contig, &start, &end, &sub_iter, &self_dim_size, &selfSlice_data, &self_stride_bytes, + &resultSlice_data, &result_stride_bytes] () { + auto index_data = index_contig.data_ptr(); + for (int64_t i = start; i < end; i++) { + auto self_i = index_data[i]; + TORCH_CHECK_INDEX((self_i >= 0) && (self_i < self_dim_size), "index out of range in self"); + auto self_data = static_cast(selfSlice_data) + self_i * self_stride_bytes; + auto result_data = static_cast(resultSlice_data) + i * result_stride_bytes; + sub_iter.unsafe_replace_operand(0, result_data); + sub_iter.unsafe_replace_operand(1, self_data); + copy_stub(sub_iter.device_type(), sub_iter, false); + }; + }); }; // parallel on inner loop in case the slice is large enough; @@ -513,14 +530,23 @@ Tensor & index_select_out_cpu_(Tensor & result, const Tensor & self, int64_t dim // use a fast loop when self and result are contiguous and of the same data type if (iter.is_contiguous() && self.scalar_type() == result.scalar_type()) { auto slice_size_bytes = slice_size * elementSize(self.scalar_type()); - at::parallel_for(0, numel, grain_size / slice_size, [&](int64_t start, int64_t end) { - for (int64_t i = start; i < end; i++) { - auto self_i = index_data[i]; - TORCH_CHECK_INDEX((self_i >= 0) && (self_i < self_dim_size), "index out of range in self"); - auto self_data = static_cast(selfSlice_data) + self_i * self_stride_bytes; - auto result_data = static_cast(resultSlice_data) + i * result_stride_bytes; - memcpy(result_data, self_data, slice_size_bytes); - } + // explicitly capture all required variables to work around windows build + // TODO: fix this when windows can correctly capture variables in nested lambda + at::parallel_for(0, numel, grain_size / slice_size, + [&index_contig, &slice_size_bytes, &self_dim_size, &selfSlice_data, + &self_stride_bytes, &resultSlice_data, &result_stride_bytes](int64_t start, int64_t end) { + AT_DISPATCH_INDEX_TYPES(index_contig.scalar_type(), "index_select_out_cpu_", + [&index_contig, &slice_size_bytes, &self_dim_size, &selfSlice_data, + &self_stride_bytes, &resultSlice_data, &result_stride_bytes, &start, &end] () { + auto index_data = index_contig.data_ptr(); + for (int64_t i = start; i < end; i++) { + auto self_i = index_data[i]; + TORCH_CHECK_INDEX((self_i >= 0) && (self_i < self_dim_size), "index out of range in self"); + auto self_data = static_cast(selfSlice_data) + self_i * self_stride_bytes; + auto result_data = static_cast(resultSlice_data) + i * result_stride_bytes; + memcpy(result_data, self_data, slice_size_bytes); + } + }); }); } else { at::parallel_for(0, numel, grain_size / slice_size, outer_loop); @@ -528,20 +554,26 @@ Tensor & index_select_out_cpu_(Tensor & result, const Tensor & self, int64_t dim } } else { TORCH_CHECK(result.dim() <= 1, "result.dim() (", result.dim(), ") must one or zero for given self.dim() (", self.dim(), ")"); - - AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Bool, self.scalar_type(), "index_select", [&] { + // explicitly capture all required variables to work around windows build + // TODO: fix this when windows can correctly capture variables in nested lambda + AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Bool, self.scalar_type(), "index_select", + [&index_contig, &self, &result, &dim, &numel] { auto self_stride = self.dim() == 0 ? 1 : self.stride(dim); auto result_stride = result.dim() == 0 ? 
1 : result.stride(dim); auto self_data_ptr = self.data_ptr(); auto result_data_ptr = result.data_ptr(); auto self_numel = self.numel(); - for (auto i = 0; i < numel; i++) { - auto self_i = index_data[i]; - TORCH_CHECK_INDEX((self_i >= 0) && (self_i < self_numel), "index out of range in self"); - scalar_t *self_ip = self_data_ptr + self_i * self_stride; - *(result_data_ptr + i * result_stride) = *self_ip; - } + AT_DISPATCH_INDEX_TYPES(index_contig.scalar_type(), "index_select_out_cpu_", + [&index_contig, &numel, &self_numel, &self_data_ptr, &self_stride, &result_data_ptr, &result_stride] { + auto index_data = index_contig.data_ptr(); + for (auto i = 0; i < numel; i++) { + auto self_i = index_data[i]; + TORCH_CHECK_INDEX((self_i >= 0) && (self_i < self_numel), "index out of range in self"); + scalar_t *self_ip = self_data_ptr + self_i * self_stride; + *(result_data_ptr + i * result_stride) = *self_ip; + } + }); }); } diff --git a/aten/src/ATen/native/cuda/Embedding.cu b/aten/src/ATen/native/cuda/Embedding.cu index cab8483093d..297b99abcd6 100644 --- a/aten/src/ATen/native/cuda/Embedding.cu +++ b/aten/src/ATen/native/cuda/Embedding.cu @@ -29,9 +29,10 @@ static const int BLOCKDIMY = 32; template + typename accscalar_t, + typename index_t> __global__ void embedding_backward_feature_kernel - (int64_t* indices, + (index_t* indices, const scalar_t* __restrict__ grad, scalar_t* __restrict__ grad_weight, int n, // OK to pass as int, we don't expect 2 billion+ samples in one shot @@ -117,10 +118,10 @@ __global__ void embedding_backward_feature_kernel } -template +template __global__ void embedding_backward_kernel( - int64_t* input, int64_t* indices, scalar_t* grad_output, scalar_t* grad_weight, - int64_t* count, int64_t numel, int64_t stride, int padding_idx) { + index_t* input, index_t* indices, scalar_t* grad_output, scalar_t* grad_weight, + index_t* count, int64_t numel, int64_t stride, int padding_idx) { using accscalar_t = acc_type; int idx = blockIdx.x * 4 + threadIdx.y; @@ -179,9 +180,9 @@ __global__ void embedding_backward_kernel( } /* Calculate norms of the rows of weight_ptr given by idx_ptr and capture them in norms */ -template +template __global__ void renorm_kernel( - scalar_t* weights, int64_t* indices, accscalar_t max_norm, + scalar_t* weights, index_t* indices, accscalar_t max_norm, accscalar_t norm_type, int64_t dim, int64_t weights_stride0, int64_t weights_stride1) { @@ -228,7 +229,7 @@ Tensor embedding_dense_backward_cuda(const Tensor & grad_, const Tensor & indice bool scale_grad_by_freq) { auto grad_arg = TensorArg(grad_, "grad", 1); auto indices_arg = TensorArg(indices, "indices", 1); - checkScalarType("embedding_backward", indices_arg, kLong); + checkScalarTypes("embedding_backward", indices_arg, {kLong, kInt}); checkSameGPU("embedding_backward", grad_arg, indices_arg); auto num_indices = indices.numel(); @@ -250,18 +251,20 @@ Tensor embedding_dense_backward_cuda(const Tensor & grad_, const Tensor & indice { AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "embedding_backward", [&] { using accscalar_t = acc_type; - embedding_backward_feature_kernel - <<>> - (indices_contig.data_ptr(), - grad.data_ptr(), - grad_weight.data_ptr(), - static_cast(num_indices), - static_cast(stride), - static_cast(padding_idx)); - TORCH_CUDA_KERNEL_LAUNCH_CHECK(); + AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_dense_backward_cuda", [&] () { + embedding_backward_feature_kernel + <<>> + (indices_contig.data_ptr(), + grad.data_ptr(), + grad_weight.data_ptr(), + static_cast(num_indices), + 
static_cast(stride), + static_cast(padding_idx)); + TORCH_CUDA_KERNEL_LAUNCH_CHECK(); + }); }); }); return grad_weight; @@ -269,61 +272,63 @@ Tensor embedding_dense_backward_cuda(const Tensor & grad_, const Tensor & indice auto sorted_indices = at::empty_like(indices, LEGACY_CONTIGUOUS_MEMORY_FORMAT); auto orig_indices = at::empty_like(indices, LEGACY_CONTIGUOUS_MEMORY_FORMAT); - using device_ptr = thrust::device_ptr; - - // Sort the inputs into sorted with the corresponding indices; we - // don't need a stable or multidimensional sort, so just use Thrust - // directly - { - sorted_indices.copy_(indices); - - auto allocator = THCThrustAllocator(globalContext().lazyInitCUDA()); - auto policy = thrust::cuda::par(allocator).on(stream); - - // Fill sortedOrigIndices with sequential indices - auto count_iter = thrust::counting_iterator(0); - auto orig_data = device_ptr(orig_indices.data_ptr()); - thrust::copy(policy, count_iter, count_iter + num_indices, orig_data); - - // Sort; a stable sort is not required - auto sorted_data = device_ptr(sorted_indices.data_ptr()); - thrust::sort_by_key(policy, sorted_data, sorted_data + num_indices, orig_data, - ThrustLTOp()); - } - Tensor count; - if (scale_grad_by_freq) { - count = at::empty_like(indices, LEGACY_CONTIGUOUS_MEMORY_FORMAT); + AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_dense_backward_cuda", [&] () { + using device_ptr = thrust::device_ptr; - auto allocator = THCThrustAllocator(globalContext().lazyInitCUDA()); - auto policy = thrust::cuda::par(allocator).on(stream); + // Sort the inputs into sorted with the corresponding indices; we + // don't need a stable or multidimensional sort, so just use Thrust + // directly + { + sorted_indices.copy_(indices); - // Compute an increasing sequence per unique item in sortedIndices: - // sorted: 2 5 5 5 7 7 8 9 9 - // count: 1 1 2 3 1 2 1 1 2 - auto sorted_data = device_ptr(sorted_indices.data_ptr()); - auto count_data = device_ptr(count.data_ptr()); - thrust::inclusive_scan_by_key( - policy, - sorted_data, - sorted_data + num_indices, - thrust::make_constant_iterator(1), - count_data - ); + auto allocator = THCThrustAllocator(globalContext().lazyInitCUDA()); + auto policy = thrust::cuda::par(allocator).on(stream); - // Take the maximum of each count per unique key in reverse: - // sorted: 2 5 5 5 7 7 8 9 9 - // count: 1 3 3 3 2 2 1 2 2 - thrust::inclusive_scan_by_key( - policy, - thrust::make_reverse_iterator(sorted_data + num_indices), - thrust::make_reverse_iterator(sorted_data), - thrust::make_reverse_iterator(count_data + num_indices), - thrust::make_reverse_iterator(count_data + num_indices), - thrust::equal_to(), - thrust::maximum() - ); - } + // Fill sortedOrigIndices with sequential indices + auto count_iter = thrust::counting_iterator(0); + auto orig_data = device_ptr(orig_indices.data_ptr()); + thrust::copy(policy, count_iter, count_iter + num_indices, orig_data); + + // Sort; a stable sort is not required + auto sorted_data = device_ptr(sorted_indices.data_ptr()); + thrust::sort_by_key(policy, sorted_data, sorted_data + num_indices, orig_data, + ThrustLTOp()); + } + + if (scale_grad_by_freq) { + count = at::empty_like(indices, LEGACY_CONTIGUOUS_MEMORY_FORMAT); + + auto allocator = THCThrustAllocator(globalContext().lazyInitCUDA()); + auto policy = thrust::cuda::par(allocator).on(stream); + + // Compute an increasing sequence per unique item in sortedIndices: + // sorted: 2 5 5 5 7 7 8 9 9 + // count: 1 1 2 3 1 2 1 1 2 + auto sorted_data = device_ptr(sorted_indices.data_ptr()); + 
auto count_data = device_ptr(count.data_ptr()); + thrust::inclusive_scan_by_key( + policy, + sorted_data, + sorted_data + num_indices, + thrust::make_constant_iterator(1), + count_data + ); + + // Take the maximum of each count per unique key in reverse: + // sorted: 2 5 5 5 7 7 8 9 9 + // count: 1 3 3 3 2 2 1 2 2 + thrust::inclusive_scan_by_key( + policy, + thrust::make_reverse_iterator(sorted_data + num_indices), + thrust::make_reverse_iterator(sorted_data), + thrust::make_reverse_iterator(count_data + num_indices), + thrust::make_reverse_iterator(count_data + num_indices), + thrust::equal_to(), + thrust::maximum() + ); + } + }); return embedding_backward_cuda_kernel(grad, orig_indices, sorted_indices, count, num_weights, padding_idx); @@ -340,31 +345,33 @@ Tensor & embedding_renorm_cuda_(Tensor & self, const Tensor & indices, auto allocator = THCThrustAllocator(globalContext().lazyInitCUDA()); auto policy = thrust::cuda::par(allocator).on(stream); - using device_ptr = thrust::device_ptr; + AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_renorm_cuda_", [&] () { + using device_ptr = thrust::device_ptr; - auto num_indices = indices.numel(); - auto indices_contig = std::get<0>(indices.sort()).contiguous(); - auto indices_data = device_ptr(indices_contig.data_ptr()); + auto num_indices = indices.numel(); + auto indices_contig = std::get<0>(indices.sort()).contiguous(); + auto indices_data = device_ptr(indices_contig.data_ptr()); - auto unique_indices = at::empty(indices.numel(), indices.options()); - auto unique_data = device_ptr(unique_indices.data_ptr()); - auto end = thrust::unique_copy(policy, indices_data, indices_data + num_indices, unique_data); - auto num_unique_indices = static_cast(end - unique_data); + auto unique_indices = at::empty(indices.numel(), indices.options()); + auto unique_data = device_ptr(unique_indices.data_ptr()); + auto end = thrust::unique_copy(policy, indices_data, indices_data + num_indices, unique_data); + auto num_unique_indices = static_cast(end - unique_data); - dim3 grid(num_unique_indices); - dim3 block(128); - int dim = self.stride(0); + dim3 grid(num_unique_indices); + dim3 block(128); + int dim = self.stride(0); - AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, self.scalar_type(), "embedding_backward", [&] { - AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "embedding_backward", [&] { - using accscalar_t = acc_type; - renorm_kernel<<>>( - self.data_ptr(), - unique_indices.data_ptr(), - static_cast(max_norm), - static_cast(norm_type), - dim, self.stride(0), self.stride(1)); - TORCH_CUDA_KERNEL_LAUNCH_CHECK(); + AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, self.scalar_type(), "embedding_backward", [&] { + AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "embedding_backward", [&] { + using accscalar_t = acc_type; + renorm_kernel<<>>( + self.data_ptr(), + unique_indices.data_ptr(), + static_cast(max_norm), + static_cast(norm_type), + dim, self.stride(0), self.stride(1)); + TORCH_CUDA_KERNEL_LAUNCH_CHECK(); + }); }); }); return self; diff --git a/aten/src/ATen/native/cuda/EmbeddingBackwardKernel.cu b/aten/src/ATen/native/cuda/EmbeddingBackwardKernel.cu index 0fd742d7b70..61f4ed72def 100644 --- a/aten/src/ATen/native/cuda/EmbeddingBackwardKernel.cu +++ b/aten/src/ATen/native/cuda/EmbeddingBackwardKernel.cu @@ -40,8 +40,9 @@ int64_t ceil_div(int64_t x, int64_t y) { return (x + y - 1) / y; } +template __global__ -void krn_partials_per_segment(int64_t *ret, const int64_t *segment_offsets, +void 
krn_partials_per_segment(index_t *ret, const index_t *segment_offsets, int64_t num_of_segments, int64_t numel) { const int id = blockIdx.x * blockDim.x + threadIdx.x; if(id < num_of_segments) { @@ -52,18 +53,19 @@ void krn_partials_per_segment(int64_t *ret, const int64_t *segment_offsets, } } +template __global__ void krn_partial_segment_offset( - int64_t *ret, - const int64_t *partials_per_segment, - const int64_t *partials_per_segment_offset, - const int64_t *segment_offsets, + index_t *ret, + const index_t *partials_per_segment, + const index_t *partials_per_segment_offset, + const index_t *segment_offsets, int64_t num_of_segments) { const int id = blockIdx.x * blockDim.x + threadIdx.x; if(id < num_of_segments) { - int64_t idx = partials_per_segment_offset[id]; - const int64_t num_partials = partials_per_segment[id]; - const int64_t segment_offset = segment_offsets[id]; + index_t idx = partials_per_segment_offset[id]; + const index_t num_partials = partials_per_segment[id]; + const index_t segment_offset = segment_offsets[id]; for (int64_t i=0; i +template __global__ void compute_grad_weight_bags( - int64_t *indices, scalar_t *gradOutput, - int64_t *offset2bag, int64_t *count, ptrdiff_t numel, - int64_t stride, int mode_mean, const int64_t *bag_size, + index_t *indices, scalar_t *gradOutput, + index_t *offset2bag, index_t *count, ptrdiff_t numel, + int64_t stride, int mode_mean, const index_t *bag_size, scalar_t* per_sample_weights, int64_t per_sample_weights_stride, - int64_t* segment_offsets, int64_t num_of_segments, + index_t* segment_offsets, int64_t num_of_segments, acc_type *grad_weight_per_segment, const int64_t stride_warped) { @@ -113,14 +115,14 @@ __global__ void compute_grad_weight_bags( grad_weight_per_segment[id * stride + startFeature] = weight; } -template +template __global__ void compute_grad_weight( - int64_t *indices, + index_t *indices, scalar_t *gradOutput, - int64_t *count, + index_t *count, ptrdiff_t numel, int64_t stride, - int64_t* segment_offsets, + index_t* segment_offsets, int64_t num_of_segments, acc_type *grad_weight_per_segment, const int64_t stride_warped) { @@ -140,7 +142,7 @@ __global__ void compute_grad_weight( accscalar_t weight = 0; for (int idx=idx_begin; idx < idx_end; ++idx) { - const int64_t target_row = indices[idx]; + const index_t target_row = indices[idx]; const accscalar_t scale = count ? (accscalar_t)1.0 / count[idx] : 1.0; weight += gradOutput[target_row * stride + startFeature] * scale; } @@ -148,12 +150,12 @@ __global__ void compute_grad_weight( } // This kernel assumes that all input tensors are contiguous. -template +template __global__ void sum_and_scatter( - int64_t *input, scalar_t *gradWeight, int64_t stride, - int64_t* segment_offsets, int64_t num_of_segments, + index_t *input, scalar_t *gradWeight, int64_t stride, + index_t* segment_offsets, int64_t num_of_segments, const acc_type *grad_weight_per_segment, - const int64_t *segment_sizes_offsets, int64_t num_of_partial_segments, + const index_t *segment_sizes_offsets, int64_t num_of_partial_segments, const int64_t padding_idx, const int64_t stride_warped) { @@ -206,118 +208,120 @@ Tensor embedding_backward_cuda_kernel( // spawn a warp per index. In this context, a segment is a number of rows that should // be summarized. 
// Unit: index in `sorted_indices` and `orig_indices` - auto segment_offsets = at::empty({numel}, orig_indices.options()); - int64_t num_of_segments; - { - auto sorted_indices_dev = thrust::device_ptr(sorted_indices.data_ptr()); - auto dummy = at::empty_like(sorted_indices, LEGACY_CONTIGUOUS_MEMORY_FORMAT); - auto dummy_dev = thrust::device_ptr(dummy.data_ptr()); - auto ends = thrust::unique_by_key_copy( + AT_DISPATCH_INDEX_TYPES(orig_indices.scalar_type(), "embedding_backward_cuda_kernel", [&] () { + auto segment_offsets = at::empty({numel}, orig_indices.options()); + int64_t num_of_segments; + { + auto sorted_indices_dev = thrust::device_ptr(sorted_indices.data_ptr()); + auto dummy = at::empty_like(sorted_indices, LEGACY_CONTIGUOUS_MEMORY_FORMAT); + auto dummy_dev = thrust::device_ptr(dummy.data_ptr()); + auto ends = thrust::unique_by_key_copy( + policy, + sorted_indices_dev, + sorted_indices_dev + numel, + thrust::make_counting_iterator(0), + dummy_dev, + thrust::device_ptr(segment_offsets.data_ptr())); + num_of_segments = thrust::get<0>(ends) - dummy_dev; + } + + // We split the segments up into sizes of `NROWS_PER_THREAD` + // Compute the number partial-segments per segment (some partial-segments + // may not be the full `NROWS_PER_THREAD` number of rows) + auto partials_per_segment = at::empty({num_of_segments}, orig_indices.options()); + { + krn_partials_per_segment<<>> ( + partials_per_segment.data_ptr(), + segment_offsets.data_ptr(), + num_of_segments, + numel); + } + + // In order to compute `partial_segment_offset`, which is the start index + // of each partial-segment in `sorted_indices`, we need to compute the + // start position of each _segment_ in `partial_segment_offset`. + // Unit: index in `partial_segment_offset` + auto partials_per_segment_offset = at::empty({num_of_segments}, orig_indices.options()); + thrust::exclusive_scan( policy, - sorted_indices_dev, - sorted_indices_dev + numel, - thrust::make_counting_iterator(0), - dummy_dev, - thrust::device_ptr(segment_offsets.data_ptr())); - num_of_segments = thrust::get<0>(ends) - dummy_dev; - } + thrust::device_ptr(partials_per_segment.data_ptr()), + thrust::device_ptr(partials_per_segment.data_ptr()+num_of_segments), + thrust::device_ptr(partials_per_segment_offset.data_ptr())); - // We split the segments up into sizes of `NROWS_PER_THREAD` - // Compute the number partial-segments per segment (some partial-segments - // may not be the full `NROWS_PER_THREAD` number of rows) - auto partials_per_segment = at::empty({num_of_segments}, orig_indices.options()); - { - krn_partials_per_segment<<>> ( - partials_per_segment.data_ptr(), - segment_offsets.data_ptr(), - num_of_segments, - numel); - } + // The total number of partial-segments is the sum of `partials_per_segment_offset` + const int num_of_partial_segments = partials_per_segment[num_of_segments-1].item() + + partials_per_segment_offset[num_of_segments-1].item(); - // In order to compute `partial_segment_offset`, which is the start index - // of each partial-segment in `sorted_indices`, we need to compute the - // start position of each _segment_ in `partial_segment_offset`. 
- // Unit: index in `partial_segment_offset` - auto partials_per_segment_offset = at::empty({num_of_segments}, orig_indices.options()); - thrust::exclusive_scan( - policy, - thrust::device_ptr(partials_per_segment.data_ptr()), - thrust::device_ptr(partials_per_segment.data_ptr()+num_of_segments), - thrust::device_ptr(partials_per_segment_offset.data_ptr())); + // Now we can compute the start position of each partial-segment + // Unit: index in `sorted_indices` and `orig_indices` + auto partial_segment_offset = at::empty({num_of_partial_segments}, orig_indices.options()); + { + krn_partial_segment_offset<<>> ( + partial_segment_offset.data_ptr(), + partials_per_segment.data_ptr(), + partials_per_segment_offset.data_ptr(), + segment_offsets.data_ptr(), + num_of_segments); + } - // The total number of partial-segments is the sum of `partials_per_segment_offset` - const int num_of_partial_segments = partials_per_segment[num_of_segments-1].item() + - partials_per_segment_offset[num_of_segments-1].item(); + const int stride_warped = ceil_div(stride, C10_WARP_SIZE)*C10_WARP_SIZE; + const int block = std::min(stride_warped, MAX_BLOCK_SIZE); + const int grid = ceil_div(num_of_partial_segments*stride_warped, block); - // Now we can compute the start position of each partial-segment - // Unit: index in `sorted_indices` and `orig_indices` - auto partial_segment_offset = at::empty({num_of_partial_segments}, orig_indices.options()); - { - krn_partial_segment_offset<<>> ( - partial_segment_offset.data_ptr(), - partials_per_segment.data_ptr(), - partials_per_segment_offset.data_ptr(), - segment_offsets.data_ptr(), - num_of_segments); - } + AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, + grad.scalar_type(), "embedding_bag_backward_cuda_compute_grad_weight", [&] { + AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "embedding_bag_backward_cuda_compute_grad_weight", [&] { + // For numerical stability, the dtype of `grad_weight_per_segment` + // should match `acc_type` + using partial_weight_t = acc_type; + TensorOptions op; + if(grad.dtype() == at::kHalf || grad.dtype() == at::kBFloat16) { + op = grad.options().dtype(at::kFloat); + } else { + op = grad.options(); + } + auto grad_weight_per_segment = at::empty({num_of_partial_segments, stride}, op); + // Compute the sum of each partial-segment and handle bags + if (offset2bag.defined()) { + compute_grad_weight_bags<<>>( + orig_indices.data_ptr(), + grad.data_ptr(), + offset2bag.data_ptr(), + count.defined() ? count.data_ptr() : nullptr, numel, stride, + mode_mean, bag_size.data_ptr(), + per_sample_weights.defined() ? per_sample_weights.data_ptr() : NULL, + per_sample_weights.defined() ? per_sample_weights.stride(0) : 0, + partial_segment_offset.data_ptr(), + num_of_partial_segments, grad_weight_per_segment.data_ptr(), + stride_warped); + } else { + compute_grad_weight<<>>( + orig_indices.data_ptr(), + grad.data_ptr(), + count.defined() ? 
count.data_ptr() : nullptr, + numel, stride, + partial_segment_offset.data_ptr(), + num_of_partial_segments, + grad_weight_per_segment.data_ptr(), + stride_warped); + } + AT_CUDA_CHECK(cudaGetLastError()); - const int stride_warped = ceil_div(stride, C10_WARP_SIZE)*C10_WARP_SIZE; - const int block = std::min(stride_warped, MAX_BLOCK_SIZE); - const int grid = ceil_div(num_of_partial_segments*stride_warped, block); - - AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, - grad.scalar_type(), "embedding_bag_backward_cuda_compute_grad_weight", [&] { - AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "embedding_bag_backward_cuda_compute_grad_weight", [&] { - // For numerical stability, the dtype of `grad_weight_per_segment` - // should match `acc_type` - using partial_weight_t = acc_type; - TensorOptions op; - if(grad.dtype() == at::kHalf || grad.dtype() == at::kBFloat16) { - op = grad.options().dtype(at::kFloat); - } else { - op = grad.options(); - } - auto grad_weight_per_segment = at::empty({num_of_partial_segments, stride}, op); - // Compute the sum of each partial-segment and handle bags - if (offset2bag.defined()) { - compute_grad_weight_bags<<>>( - orig_indices.data_ptr(), - grad.data_ptr(), - offset2bag.data_ptr(), - count.defined() ? count.data_ptr() : nullptr, numel, stride, - mode_mean, bag_size.data_ptr(), - per_sample_weights.defined() ? per_sample_weights.data_ptr() : NULL, - per_sample_weights.defined() ? per_sample_weights.stride(0) : 0, - partial_segment_offset.data_ptr(), - num_of_partial_segments, grad_weight_per_segment.data_ptr(), - stride_warped); - } else { - compute_grad_weight<<>>( - orig_indices.data_ptr(), - grad.data_ptr(), - count.defined() ? count.data_ptr() : nullptr, - numel, stride, - partial_segment_offset.data_ptr(), + // Finally, we sum all the partial-sums and scatter them + // into `grad_weight`. + const int grid2 = ceil_div(num_of_segments*stride_warped, block); + sum_and_scatter<<>>( + sorted_indices.data_ptr(), + grad_weight.data_ptr(), + stride, + segment_offsets.data_ptr(), + num_of_segments, grad_weight_per_segment.data_ptr(), + partials_per_segment_offset.data_ptr(), num_of_partial_segments, - grad_weight_per_segment.data_ptr(), + padding_idx, stride_warped); - } - AT_CUDA_CHECK(cudaGetLastError()); - - // Finally, we sum all the partial-sums and scatter them - // into `grad_weight`. - const int grid2 = ceil_div(num_of_segments*stride_warped, block); - sum_and_scatter<<>>( - sorted_indices.data_ptr(), - grad_weight.data_ptr(), - stride, - segment_offsets.data_ptr(), - num_of_segments, grad_weight_per_segment.data_ptr(), - partials_per_segment_offset.data_ptr(), - num_of_partial_segments, - padding_idx, - stride_warped); - AT_CUDA_CHECK(cudaGetLastError()); + AT_CUDA_CHECK(cudaGetLastError()); + }); }); }); return grad_weight; diff --git a/aten/src/ATen/native/cuda/EmbeddingBag.cu b/aten/src/ATen/native/cuda/EmbeddingBag.cu index d128f2f63ca..51755cc9f0f 100644 --- a/aten/src/ATen/native/cuda/EmbeddingBag.cu +++ b/aten/src/ATen/native/cuda/EmbeddingBag.cu @@ -31,12 +31,12 @@ constexpr int MODE_MAX = 2; // This kernel assumes that all input tensors except `weight` and // per_sample_weights are contiguous. 
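The pattern used throughout these files is a nested dispatch: one macro selects scalar_t from the floating-point dtype of the weight or gradient, the other selects index_t (int32_t or int64_t) from the indices dtype, and the kernel is instantiated for that pair. Below is a hedged, self-contained sketch of the shape of that nesting; `my_kernel` and `launch` are hypothetical stand-ins, not functions from the patch.

#include <cstdint>
#include <ATen/ATen.h>
#include <ATen/Dispatch.h>

// Hypothetical kernel wrapper: one instantiation per (scalar_t, index_t) pair.
template <typename scalar_t, typename index_t>
void my_kernel(const scalar_t* /*grad*/, const index_t* /*indices*/, int64_t /*n*/) {
  // real code would launch a CUDA kernel here
}

void launch(const at::Tensor& grad, const at::Tensor& indices) {
  // Outer dispatch: scalar_t comes from the floating-point dtype of grad.
  AT_DISPATCH_FLOATING_TYPES_AND2(
      at::ScalarType::Half, at::ScalarType::BFloat16,
      grad.scalar_type(), "launch", [&] {
        // Inner dispatch: index_t is int32_t or int64_t, from the indices dtype.
        AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "launch", [&] {
          my_kernel<scalar_t, index_t>(
              grad.data_ptr<scalar_t>(),
              indices.data_ptr<index_t>(),
              indices.numel());
        });
      });
}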
-template +template __global__ void EmbeddingBag_updateOutputKernel( - int64_t *input, int64_t *offsets, scalar_t *weight, scalar_t *output, - int64_t *offset2bag, int64_t numIndices, int64_t numBags, + index_t *input, index_t *offsets, scalar_t *weight, scalar_t *output, + index_t *offset2bag, int64_t numIndices, int64_t numBags, int64_t featureSize, int64_t weight_stride0, int64_t weight_stride1, - int mode, int64_t *bag_size, int64_t *max_indices, + int mode, index_t *bag_size, index_t *max_indices, scalar_t* per_sample_weights, int64_t per_sample_weights_stride) { // the strategy here is that each bag x feature is handled by a single thread @@ -135,62 +135,65 @@ Tensor embedding_bag_backward_cuda_sum_avg( auto sorted_indices = at::empty_like(indices, LEGACY_CONTIGUOUS_MEMORY_FORMAT); auto orig_indices = at::empty_like(indices, LEGACY_CONTIGUOUS_MEMORY_FORMAT); - using device_ptr = thrust::device_ptr; - - // Sort the inputs into sorted with the corresponding indices; we - // don't need a stable or multidimensional sort, so just use Thrust - // directly - { - sorted_indices.copy_(indices); - - auto allocator = THCThrustAllocator(globalContext().lazyInitCUDA()); - auto policy = thrust::cuda::par(allocator).on(stream); - - // Fill sortedOrigIndices with sequential indices - auto count_iter = thrust::counting_iterator(0); - auto orig_data = device_ptr(orig_indices.data_ptr()); - thrust::copy(policy, count_iter, count_iter + numel, orig_data); - - // Sort; a stable sort is not required - auto sorted_data = device_ptr(sorted_indices.data_ptr()); - thrust::sort_by_key(policy, sorted_data, sorted_data + numel, orig_data, - ThrustLTOp()); - } - Tensor count; - if (scale_grad_by_freq) { - count = at::empty_like(indices, LEGACY_CONTIGUOUS_MEMORY_FORMAT); - auto allocator = THCThrustAllocator(globalContext().lazyInitCUDA()); - auto policy = thrust::cuda::par(allocator).on(stream); + AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_bag_backward_cuda_sum_avg", [&] () { + using device_ptr = thrust::device_ptr; - // Compute an increasing sequence per unique item in sortedIndices: - // sorted: 2 5 5 5 7 7 8 9 9 - // count: 1 1 2 3 1 2 1 1 2 - auto sorted_data = device_ptr(sorted_indices.data_ptr()); - auto count_data = device_ptr(count.data_ptr()); - thrust::inclusive_scan_by_key(policy, sorted_data, sorted_data + numel, - thrust::make_constant_iterator(1), - count_data); + // Sort the inputs into sorted with the corresponding indices; we + // don't need a stable or multidimensional sort, so just use Thrust + // directly + { + sorted_indices.copy_(indices); - // Take the maximum of each count per unique key in reverse: - // sorted: 2 5 5 5 7 7 8 9 9 - // count: 1 3 3 3 2 2 1 2 2 - thrust::inclusive_scan_by_key( - policy, thrust::make_reverse_iterator(sorted_data + numel), - thrust::make_reverse_iterator(sorted_data), - thrust::make_reverse_iterator(count_data + numel), - thrust::make_reverse_iterator(count_data + numel), - thrust::equal_to(), thrust::maximum()); - } + auto allocator = THCThrustAllocator(globalContext().lazyInitCUDA()); + auto policy = thrust::cuda::par(allocator).on(stream); + + // Fill sortedOrigIndices with sequential indices + auto count_iter = thrust::counting_iterator(0); + auto orig_data = device_ptr(orig_indices.data_ptr()); + thrust::copy(policy, count_iter, count_iter + numel, orig_data); + + // Sort; a stable sort is not required + auto sorted_data = device_ptr(sorted_indices.data_ptr()); + thrust::sort_by_key(policy, sorted_data, sorted_data + numel, orig_data, + 
ThrustLTOp()); + } + + if (scale_grad_by_freq) { + count = at::empty_like(indices, LEGACY_CONTIGUOUS_MEMORY_FORMAT); + + auto allocator = THCThrustAllocator(globalContext().lazyInitCUDA()); + auto policy = thrust::cuda::par(allocator).on(stream); + + // Compute an increasing sequence per unique item in sortedIndices: + // sorted: 2 5 5 5 7 7 8 9 9 + // count: 1 1 2 3 1 2 1 1 2 + auto sorted_data = device_ptr(sorted_indices.data_ptr()); + auto count_data = device_ptr(count.data_ptr()); + thrust::inclusive_scan_by_key(policy, sorted_data, sorted_data + numel, + thrust::make_constant_iterator(1), + count_data); + + // Take the maximum of each count per unique key in reverse: + // sorted: 2 5 5 5 7 7 8 9 9 + // count: 1 3 3 3 2 2 1 2 2 + thrust::inclusive_scan_by_key( + policy, thrust::make_reverse_iterator(sorted_data + numel), + thrust::make_reverse_iterator(sorted_data), + thrust::make_reverse_iterator(count_data + numel), + thrust::make_reverse_iterator(count_data + numel), + thrust::equal_to(), thrust::maximum()); + } + }); return embedding_backward_cuda_kernel(grad, orig_indices, sorted_indices, count, num_weights, /* padding_idx= */ -1, scale_grad_by_freq, mode == MODE_MEAN, offset2bag, bag_size, per_sample_weights); } -template +template __global__ void EmbeddingBag_accGradParametersKernel_max( - int64_t *max_indices, scalar_t *gradOutput, + index_t *max_indices, scalar_t *gradOutput, scalar_t *gradWeight, int64_t stride, int64_t numBags) { using accscalar_t = acc_type; @@ -205,7 +208,7 @@ __global__ void EmbeddingBag_accGradParametersKernel_max( if (featureDim < stride) { int64_t bag = chunk / chunksPerBag; - int64_t word_idx = max_indices[bag * stride + featureDim]; + index_t word_idx = max_indices[bag * stride + featureDim]; if (word_idx >= 0) { // If bag is empty, we have max_indices[idx] set to -1 in forward. 
gpuAtomicAdd(&(gradWeight[word_idx * stride + featureDim]), @@ -236,10 +239,12 @@ Tensor embedding_bag_backward_cuda_max(const Tensor &grad, AT_DISPATCH_FLOATING_TYPES_AND_HALF( grad.scalar_type(), "embedding_bag_backward_cuda_max", [&] { - EmbeddingBag_accGradParametersKernel_max< - scalar_t><<>>( - max_indices.data_ptr(), grad.data_ptr(), - grad_weight.data_ptr(), stride, numBags); + AT_DISPATCH_INDEX_TYPES(max_indices.scalar_type(), "embedding_bag_backward_cuda_max", [&] () { + EmbeddingBag_accGradParametersKernel_max< + scalar_t, index_t><<>>( + max_indices.data_ptr(), grad.data_ptr(), + grad_weight.data_ptr(), stride, numBags); + }); }); AT_CUDA_CHECK(cudaGetLastError()); @@ -275,9 +280,10 @@ _embedding_bag_cuda(const Tensor &weight, const Tensor &indices, const Tensor& per_sample_weights, bool include_last_offset) { auto indices_arg = TensorArg(indices, "indices", 1); - checkScalarType("embedding_bag_cuda", indices_arg, kLong); + checkScalarTypes("embedding_bag_cuda", indices_arg, {kLong, kInt}); auto offsets_arg = TensorArg(offsets, "offsets", 1); - checkScalarType("embedding_bag_cuda", offsets_arg, kLong); + checkScalarTypes("embedding_bag_cuda", offsets_arg, {kLong, kInt}); + checkSameType("embedding_bag_cuda", indices_arg, offsets_arg); auto weight_arg = TensorArg(weight, "weight", 1); checkSameGPU("embedding_bag_cuda", weight_arg, indices_arg); checkSameGPU("embedding_bag_cuda", weight_arg, offsets_arg); @@ -320,14 +326,16 @@ _embedding_bag_cuda(const Tensor &weight, const Tensor &indices, int grid = 1024; AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, weight.scalar_type(), "embedding_bag_cuda", [&] { AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "embedding_bag_cuda", [&] { - EmbeddingBag_updateOutputKernel<<>>( - indices.data_ptr(), offsets.data_ptr(), - weight.data_ptr(), output.data_ptr(), - offset2bag.data_ptr(), numIndices, numBags, featureSize, - weight.stride(0), weight.stride(1), mode, bag_size.data_ptr(), - mode == MODE_MAX ? max_indices.data_ptr() : NULL, - per_sample_weights.defined() ? per_sample_weights.data_ptr() : NULL, - per_sample_weights.defined() ? per_sample_weights.stride(0) : 0); + AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_bag_cuda", [&] () { + EmbeddingBag_updateOutputKernel<<>>( + indices.data_ptr(), offsets.data_ptr(), + weight.data_ptr(), output.data_ptr(), + offset2bag.data_ptr(), numIndices, numBags, featureSize, + weight.stride(0), weight.stride(1), mode, bag_size.data_ptr(), + mode == MODE_MAX ? max_indices.data_ptr() : NULL, + per_sample_weights.defined() ? per_sample_weights.data_ptr() : NULL, + per_sample_weights.defined() ? 
per_sample_weights.stride(0) : 0); + }); }); }); @@ -387,12 +395,12 @@ static scalar_t warpReduceSum(scalar_t val) { return val; } -template +template __global__ static void _embedding_bag_per_sample_weights_backward_kernel( const scalar_t* grad, int64_t grad_stride0, int64_t grad_stride1, const scalar_t* weight, int64_t weight_stride0, int64_t weight_stride1, - const int64_t* indices, // contiguous - const int64_t* offset2bag, // contiguous + const index_t* indices, // contiguous + const index_t* offset2bag, // contiguous int64_t num_samples, int64_t embedding_features, scalar_t* output) { @@ -457,16 +465,18 @@ Tensor _embedding_bag_per_sample_weights_backward_cuda( AT_DISPATCH_FLOATING_TYPES_AND_HALF( grad.scalar_type(), "_embedding_bag_per_sample_weights_backward_cuda", [&]() { - _embedding_bag_per_sample_weights_backward_kernel - <<>>( - grad.data_ptr(), grad.stride(0), grad.stride(1), - weight.data_ptr(), weight.stride(0), weight.stride(1), - indices.data_ptr(), - offset2bag.data_ptr(), - num_samples, - embedding_features, - output.data_ptr()); - AT_CUDA_CHECK(cudaGetLastError()); + AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "_embedding_bag_per_sample_weights_backward_cuda", [&]() { + _embedding_bag_per_sample_weights_backward_kernel + <<>>( + grad.data_ptr(), grad.stride(0), grad.stride(1), + weight.data_ptr(), weight.stride(0), weight.stride(1), + indices.data_ptr(), + offset2bag.data_ptr(), + num_samples, + embedding_features, + output.data_ptr()); + AT_CUDA_CHECK(cudaGetLastError()); + }); } ); return output; diff --git a/aten/src/ATen/native/cuda/Indexing.cu b/aten/src/ATen/native/cuda/Indexing.cu index 37ed52755a5..d3c0b8d2955 100644 --- a/aten/src/ATen/native/cuda/Indexing.cu +++ b/aten/src/ATen/native/cuda/Indexing.cu @@ -308,10 +308,10 @@ static ptrdiff_t getSliceSize(const Tensor & dst, // the number of indices chosen is large, then the // indexAddLargeIndex kernel is a better choice to increase // parallelism. -template +template __global__ void indexAddSmallIndex(cuda::detail::TensorInfo dst, cuda::detail::TensorInfo src, - cuda::detail::TensorInfo indices, + cuda::detail::TensorInfo indices, int dstAddDim, int srcAddDim, IndexType innerSize, @@ -324,7 +324,7 @@ __global__ void indexAddSmallIndex(cuda::detail::TensorInfo dst, for (IndexType srcIndex = 0; srcIndex < indices.sizes[0]; ++srcIndex) { // Lua indices begin at 1 IndexType dstIndex = - indices.data[cuda::detail::IndexToOffset::get(srcIndex, indices)]; + indices.data[cuda::detail::IndexToOffset::get(srcIndex, indices)]; CUDA_KERNEL_ASSERT(dstIndex < dstAddDimSize); // We stride over the output ignoring the indexed dimension @@ -351,11 +351,11 @@ __global__ void indexAddSmallIndex(cuda::detail::TensorInfo dst, // the number of indices chosen is small, then the // indexAddSmallIndex kernel is a better choice to reduce memory // accesses. 
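The user-visible effect of the relaxed TORCH_CHECK above is that index_add_ on CUDA now also accepts an int32 index tensor. A minimal ATen-level usage sketch follows (arbitrary shapes and values; `index_add_int32_example` is a hypothetical helper, not part of the patch).

#include <ATen/ATen.h>

// Hypothetical helper: exercises index_add_ with int32 indices on CUDA.
void index_add_int32_example() {
  auto opts   = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA);
  auto self   = at::zeros({4, 3}, opts);
  auto source = at::ones({2, 3}, opts);
  // An int32 index tensor; before this change only kLong passed the check above.
  auto index  = at::arange(0, 2, opts.dtype(at::kInt));
  self.index_add_(/*dim=*/0, index, source);  // rows 0 and 1 of self become ones
}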
-template __global__ void indexAddLargeIndex(cuda::detail::TensorInfo dst, cuda::detail::TensorInfo src, - cuda::detail::TensorInfo indices, + cuda::detail::TensorInfo indices, int dstAddDim, int srcAddDim, IndexType totalSize, @@ -378,7 +378,7 @@ __global__ void indexAddLargeIndex(cuda::detail::TensorInfo dst, // Lua indices begin at 1 IndexType dstIndex = - indices.data[cuda::detail::IndexToOffset::get(srcIndex, indices)]; + indices.data[cuda::detail::IndexToOffset::get(srcIndex, indices)]; CUDA_KERNEL_ASSERT(dstIndex < dstAddDimSize); IndexType dstOffset = @@ -438,7 +438,7 @@ Tensor& index_add_cuda_(Tensor & self, int64_t dim, const Tensor & index, const checkAllSameGPU("index_add", {self_arg, index_arg, source_arg}); TORCH_CHECK_INDEX(index.dim() <= 1, "index_add_(): Index is supposed to be a vector"); - TORCH_CHECK(index.scalar_type() == ScalarType::Long, "index_add_(): Expected dtype int64 for index"); + TORCH_CHECK(index.scalar_type() == ScalarType::Long || index.scalar_type() == ScalarType::Int, "index_add_(): Expected dtype int32/int64 for index"); TORCH_CHECK(self.scalar_type() == source.scalar_type(), "index_add_(): self and source must have the same scalar type"); TORCH_CHECK(dim == 0 || dim < source.dim(), @@ -476,15 +476,15 @@ Tensor& index_add_cuda_(Tensor & self, int64_t dim, const Tensor & index, const int mpc = at::cuda::getCurrentDeviceProperties()->multiProcessorCount; -#define SMALL_INDEX(TENSOR_TYPE, TYPE, SELF_DIM, SOURCE_DIM, IDX_DIM) \ - indexAddSmallIndex \ +#define SMALL_INDEX(TENSOR_TYPE, INDICES_TYPE, TYPE, SELF_DIM, SOURCE_DIM, IDX_DIM) \ + indexAddSmallIndex \ <<>>( \ selfInfo, sourceInfo, indexInfo, \ selfAddDim, sourceAddDim, sliceSize, selfAddDimSize); -#define LARGE_INDEX(TENSOR_TYPE, TYPE, \ +#define LARGE_INDEX(TENSOR_TYPE, INDICES_TYPE, TYPE, \ SELF_DIM, SOURCE_DIM, IDX_DIM, IDX_IS_MAJOR) \ - indexAddLargeIndex \ <<>>( \ selfInfo, sourceInfo, indexInfo, \ @@ -507,49 +507,50 @@ Tensor& index_add_cuda_(Tensor & self, int64_t dim, const Tensor & index, const cuda::detail::getTensorInfo(self_); int selfAddDim = selfInfo.collapseDims(dim); selfInfo.reduceDim(selfAddDim); + AT_DISPATCH_INDEX_TYPES(index.scalar_type(), "index_add_cuda_", [&] () { + auto sourceInfo = + cuda::detail::getTensorInfo(source_); + int sourceAddDim = sourceInfo.collapseDims(dim); + sourceInfo.reduceDim(sourceAddDim); - auto sourceInfo = - cuda::detail::getTensorInfo(source_); - int sourceAddDim = sourceInfo.collapseDims(dim); - sourceInfo.reduceDim(sourceAddDim); + auto indexInfo = + cuda::detail::getTensorInfo(index); + indexInfo.collapseDims(); - auto indexInfo = - cuda::detail::getTensorInfo(index); - indexInfo.collapseDims(); - - // A reasonable choice for when to have each thread iterate over - // index to choose - if (numIndex <= 16) { - if (selfInfo.dims == 1 && sourceInfo.dims == 1 && indContig) { - SMALL_INDEX(scalar_t, unsigned int, 1, 1, -2); - } else if (selfInfo.dims == 2 && sourceInfo.dims == 2 && indContig) { - SMALL_INDEX(scalar_t, unsigned int, 2, 2, -2); - } else if (selfInfo.dims == 3 && sourceInfo.dims == 3 && indContig) { - SMALL_INDEX(scalar_t, unsigned int, 3, 3, -2); - } else { - SMALL_INDEX(scalar_t, unsigned int, -1, -1, -1); - } - } else { - bool indexIsMajor = indexShouldBeMajor(selfInfo, selfAddDim); - - if (selfInfo.dims == 1 && sourceInfo.dims == 1 && indContig) { - LARGE_INDEX(scalar_t, unsigned int, 1, 1, -2, true); - } else if (selfInfo.dims == 2 && sourceInfo.dims == 2 && indContig) { - if (indexIsMajor) { - LARGE_INDEX(scalar_t, unsigned int, 2, 2, 
-2, true); + // A reasonable choice for when to have each thread iterate over + // index to choose + if (numIndex <= 16) { + if (selfInfo.dims == 1 && sourceInfo.dims == 1 && indContig) { + SMALL_INDEX(scalar_t, index_t, unsigned int, 1, 1, -2); + } else if (selfInfo.dims == 2 && sourceInfo.dims == 2 && indContig) { + SMALL_INDEX(scalar_t, index_t, unsigned int, 2, 2, -2); + } else if (selfInfo.dims == 3 && sourceInfo.dims == 3 && indContig) { + SMALL_INDEX(scalar_t, index_t, unsigned int, 3, 3, -2); } else { - LARGE_INDEX(scalar_t, unsigned int, 2, 2, -2, false); - } - } else if (selfInfo.dims == 3 && sourceInfo.dims == 3 && indContig) { - if (indexIsMajor) { - LARGE_INDEX(scalar_t, unsigned int, 3, 3, -2, true); - } else { - LARGE_INDEX(scalar_t, unsigned int, 3, 3, -2, false); + SMALL_INDEX(scalar_t, index_t, unsigned int, -1, -1, -1); } } else { - LARGE_INDEX(scalar_t, unsigned int, -1, -1, -1, true); + bool indexIsMajor = indexShouldBeMajor(selfInfo, selfAddDim); + + if (selfInfo.dims == 1 && sourceInfo.dims == 1 && indContig) { + LARGE_INDEX(scalar_t, index_t, unsigned int, 1, 1, -2, true); + } else if (selfInfo.dims == 2 && sourceInfo.dims == 2 && indContig) { + if (indexIsMajor) { + LARGE_INDEX(scalar_t, index_t, unsigned int, 2, 2, -2, true); + } else { + LARGE_INDEX(scalar_t, index_t, unsigned int, 2, 2, -2, false); + } + } else if (selfInfo.dims == 3 && sourceInfo.dims == 3 && indContig) { + if (indexIsMajor) { + LARGE_INDEX(scalar_t, index_t, unsigned int, 3, 3, -2, true); + } else { + LARGE_INDEX(scalar_t, index_t, unsigned int, 3, 3, -2, false); + } + } else { + LARGE_INDEX(scalar_t, index_t, unsigned int, -1, -1, -1, true); + } } - } + }); }); }); } else { @@ -565,11 +566,13 @@ Tensor& index_add_cuda_(Tensor & self, int64_t dim, const Tensor & index, const int sourceAddDim = sourceInfo.collapseDims(dim); sourceInfo.reduceDim(sourceAddDim); - cuda::detail::TensorInfo indexInfo = - cuda::detail::getTensorInfo(index); - indexInfo.collapseDims(); + AT_DISPATCH_INDEX_TYPES(index.scalar_type(), "index_add_cuda_", [&] () { + cuda::detail::TensorInfo indexInfo = + cuda::detail::getTensorInfo(index); + indexInfo.collapseDims(); - LARGE_INDEX(scalar_t, uint64_t, -1, -1, -1, true); + LARGE_INDEX(scalar_t, index_t, uint64_t, -1, -1, -1, true); + }); }); }); } @@ -586,10 +589,10 @@ namespace { // the number of indices chosen is large, then the // indexSelectLargeIndex kernel is a better choice to increase // parallelism. -template +template __global__ void indexSelectSmallIndex(cuda::detail::TensorInfo dst, cuda::detail::TensorInfo src, - cuda::detail::TensorInfo indices, + cuda::detail::TensorInfo indices, int dstSelectDim, int srcSelectDim, IndexType innerSize, @@ -601,7 +604,7 @@ __global__ void indexSelectSmallIndex(cuda::detail::TensorInfo dst // re-accessing indices in addition to src elements can be slow. for (IndexType dstIndex = 0; dstIndex < indices.sizes[0]; ++dstIndex) { IndexType srcIndex = - indices.data[cuda::detail::IndexToOffset::get(dstIndex, indices)]; + indices.data[cuda::detail::IndexToOffset::get(dstIndex, indices)]; CUDA_KERNEL_ASSERT(srcIndex < srcSelectDimSize); // We stride over the output ignoring the indexed dimension @@ -628,11 +631,11 @@ __global__ void indexSelectSmallIndex(cuda::detail::TensorInfo dst // the number of indices chosen is small, then the // indexSelectSmallIndex kernel is a better choice to reduce memory // accesses. 
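index_select follows the same pattern: the indices TensorInfo becomes templated on index_t, so an int32 index tensor flows straight through, while the dispatch code below still picks the small-index kernel for 16 or fewer indices and the large-index kernel otherwise. An illustrative usage sketch (`index_select_int32_example` is a hypothetical helper, not part of the patch):

#include <ATen/ATen.h>

// Hypothetical helper: index_select with an int32 index tensor.
// Assumes self has at least 3 rows along dim 0.
at::Tensor index_select_int32_example(const at::Tensor& self) {
  // Three indices, so the dispatch below takes the numIndices <= 16 path and
  // runs the indexSelectSmallIndex kernel.
  auto index = at::arange(
      0, 3, at::TensorOptions().dtype(at::kInt).device(self.device()));
  return at::index_select(self, /*dim=*/0, index);
}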
-template __global__ void indexSelectLargeIndex(cuda::detail::TensorInfo dst, cuda::detail::TensorInfo src, - cuda::detail::TensorInfo indices, + cuda::detail::TensorInfo indices, int dstSelectDim, int srcSelectDim, IndexType totalSize, @@ -654,7 +657,7 @@ __global__ void indexSelectLargeIndex(cuda::detail::TensorInfo dst } IndexType srcIndex = - indices.data[cuda::detail::IndexToOffset::get(dstIndex, indices)]; + indices.data[cuda::detail::IndexToOffset::get(dstIndex, indices)]; CUDA_KERNEL_ASSERT(srcIndex < srcSelectDimSize); IndexType dstOffset = @@ -722,16 +725,16 @@ void index_select_out_cuda_impl(Tensor& out, const Tensor& self, long dim, int mpc = at::cuda::getCurrentDeviceProperties()->multiProcessorCount; -#define SMALL_INDEX(TENSOR_TYPE, TYPE, DST_DIM, SRC_DIM, IDX_DIM) \ - indexSelectSmallIndex \ +#define SMALL_INDEX(TENSOR_TYPE, INDICES_TYPE, TYPE, DST_DIM, SRC_DIM, IDX_DIM) \ + indexSelectSmallIndex \ <<>>( \ outInfo, selfInfo, indicesInfo, \ outSelectDim, selfSelectDim, static_cast(sliceSize), \ selfSelectDimSize); -#define LARGE_INDEX(TENSOR_TYPE, TYPE, \ +#define LARGE_INDEX(TENSOR_TYPE, INDICES_TYPE, TYPE, \ DST_DIM, SRC_DIM, IDX_DIM, IDX_IS_MAJOR) \ - indexSelectLargeIndex \ <<>>( \ outInfo, selfInfo, indicesInfo, \ @@ -755,42 +758,44 @@ void index_select_out_cuda_impl(Tensor& out, const Tensor& self, long dim, int selfSelectDim = selfInfo.collapseDims(dim); selfInfo.reduceDim(selfSelectDim); - auto indicesInfo = tensorInfoLegacyIfScalar(cuda::detail::getTensorInfo(index)); - indicesInfo.collapseDims(); + AT_DISPATCH_INDEX_TYPES(index.scalar_type(), "index_select_out_cuda_impl", [&] () { + auto indicesInfo = tensorInfoLegacyIfScalar(cuda::detail::getTensorInfo(index)); + indicesInfo.collapseDims(); - // A reasonable choice for when to have each thread iterate over - // indices to choose - if (numIndices <= 16) { - if (outInfo.dims == 1 && selfInfo.dims == 1 && indContig) { - SMALL_INDEX(scalar_t, unsigned int, 1, 1, -2); - } else if (outInfo.dims == 2 && selfInfo.dims == 2 && indContig) { - SMALL_INDEX(scalar_t, unsigned int, 2, 2, -2); - } else if (outInfo.dims == 3 && selfInfo.dims == 3 && indContig) { - SMALL_INDEX(scalar_t, unsigned int, 3, 3, -2); - } else { - SMALL_INDEX(scalar_t, unsigned int, -1, -1, -1); - } - } else { - bool indexIsMajor = indexShouldBeMajor(outInfo, outSelectDim); - - if (outInfo.dims == 1 && selfInfo.dims == 1 && indContig) { - LARGE_INDEX(scalar_t, unsigned int, 1, 1, -2, true); - } else if (outInfo.dims == 2 && selfInfo.dims == 2 && indContig) { - if (indexIsMajor) { - LARGE_INDEX(scalar_t, unsigned int, 2, 2, -2, true); + // A reasonable choice for when to have each thread iterate over + // indices to choose + if (numIndices <= 16) { + if (outInfo.dims == 1 && selfInfo.dims == 1 && indContig) { + SMALL_INDEX(scalar_t, index_t, unsigned int, 1, 1, -2); + } else if (outInfo.dims == 2 && selfInfo.dims == 2 && indContig) { + SMALL_INDEX(scalar_t, index_t, unsigned int, 2, 2, -2); + } else if (outInfo.dims == 3 && selfInfo.dims == 3 && indContig) { + SMALL_INDEX(scalar_t, index_t, unsigned int, 3, 3, -2); } else { - LARGE_INDEX(scalar_t, unsigned int, 2, 2, -2, false); - } - } else if (outInfo.dims == 3 && selfInfo.dims == 3 && indContig) { - if (indexIsMajor) { - LARGE_INDEX(scalar_t, unsigned int, 3, 3, -2, true); - } else { - LARGE_INDEX(scalar_t, unsigned int, 3, 3, -2, false); + SMALL_INDEX(scalar_t, index_t, unsigned int, -1, -1, -1); } } else { - LARGE_INDEX(scalar_t, unsigned int, -1, -1, -1, true); + bool indexIsMajor = 
indexShouldBeMajor(outInfo, outSelectDim); + + if (outInfo.dims == 1 && selfInfo.dims == 1 && indContig) { + LARGE_INDEX(scalar_t, index_t, unsigned int, 1, 1, -2, true); + } else if (outInfo.dims == 2 && selfInfo.dims == 2 && indContig) { + if (indexIsMajor) { + LARGE_INDEX(scalar_t, index_t, unsigned int, 2, 2, -2, true); + } else { + LARGE_INDEX(scalar_t, index_t, unsigned int, 2, 2, -2, false); + } + } else if (outInfo.dims == 3 && selfInfo.dims == 3 && indContig) { + if (indexIsMajor) { + LARGE_INDEX(scalar_t, index_t, unsigned int, 3, 3, -2, true); + } else { + LARGE_INDEX(scalar_t, index_t, unsigned int, 3, 3, -2, false); + } + } else { + LARGE_INDEX(scalar_t, index_t, unsigned int, -1, -1, -1, true); + } } - } + }); } else { auto outInfo = tensorInfoLegacyIfScalar(cuda::detail::getTensorInfo(out)); int outSelectDim = outInfo.collapseDims(dim); @@ -799,11 +804,12 @@ void index_select_out_cuda_impl(Tensor& out, const Tensor& self, long dim, auto selfInfo = tensorInfoLegacyIfScalar(cuda::detail::getTensorInfo(self)); int selfSelectDim = selfInfo.collapseDims(dim); selfInfo.reduceDim(selfSelectDim); + AT_DISPATCH_INDEX_TYPES(index.scalar_type(), "index_select_out_cuda_impl", [&] () { + auto indicesInfo = tensorInfoLegacyIfScalar(cuda::detail::getTensorInfo(index)); + indicesInfo.collapseDims(); - auto indicesInfo = tensorInfoLegacyIfScalar(cuda::detail::getTensorInfo(index)); - indicesInfo.collapseDims(); - - LARGE_INDEX(scalar_t, uint64_t, -1, -1, -1, true); + LARGE_INDEX(scalar_t, index_t, uint64_t, -1, -1, -1, true); + }); } #undef SMALL_INDEX #undef LARGE_INDEX diff --git a/caffe2/perfkernels/embedding_lookup_fused_8bit_rowwise_idx_avx2.cc b/caffe2/perfkernels/embedding_lookup_fused_8bit_rowwise_idx_avx2.cc index bc4b7730dce..b32efc9eae4 100644 --- a/caffe2/perfkernels/embedding_lookup_fused_8bit_rowwise_idx_avx2.cc +++ b/caffe2/perfkernels/embedding_lookup_fused_8bit_rowwise_idx_avx2.cc @@ -17,7 +17,7 @@ static bool Fused8BitRowwiseEmbeddingLookupIdx_int32_t_float_float__avx2_fma( const int64_t data_size, const float* input, const int* indices, - const int64_t* offsets, + const int* offsets, const float* weights, bool normalize_by_lengths, float* out) { @@ -401,7 +401,7 @@ bool Fused8BitRowwiseEmbeddingLookupIdx_int32_t_float_float_false__avx2_fma( const int64_t data_size, const float* input, const int* indices, - const int64_t* offsets, + const int* offsets, const float* weights, bool normalize_by_lengths, float* out) { @@ -425,7 +425,7 @@ bool Fused8BitRowwiseEmbeddingLookupIdx_int32_t_float_float_true__avx2_fma( const int64_t data_size, const float* input, const int* indices, - const int64_t* offsets, + const int* offsets, const float* weights, bool normalize_by_lengths, float* out) { @@ -883,7 +883,7 @@ static bool Fused8BitRowwiseEmbeddingLookupIdx_int32_t_half_float__avx2_fma( const int64_t data_size, const at::Half* input, const int* indices, - const int64_t* offsets, + const int* offsets, const float* weights, bool normalize_by_lengths, float* out) { @@ -1387,7 +1387,7 @@ bool Fused8BitRowwiseEmbeddingLookupIdx_int32_t_half_float_false__avx2_fma( const int64_t data_size, const at::Half* input, const int* indices, - const int64_t* offsets, + const int* offsets, const float* weights, bool normalize_by_lengths, float* out) { @@ -1410,7 +1410,7 @@ bool Fused8BitRowwiseEmbeddingLookupIdx_int32_t_half_float_true__avx2_fma( const int64_t data_size, const at::Half* input, const int* indices, - const int64_t* offsets, + const int* offsets, const float* weights, bool 
normalize_by_lengths, float* out) { @@ -1987,7 +1987,7 @@ static bool Fused8BitRowwiseEmbeddingLookupIdx_int32_t_uint8_t_float__avx2_fma( const int64_t data_size, const uint8_t* input, const int* indices, - const int64_t* offsets, + const int* offsets, const float* weights, bool normalize_by_lengths, float* out) { @@ -2514,7 +2514,7 @@ bool Fused8BitRowwiseEmbeddingLookupIdx_int32_t_uint8_t_float_false__avx2_fma( const int64_t data_size, const uint8_t* input, const int* indices, - const int64_t* offsets, + const int* offsets, const float* weights, bool normalize_by_lengths, float* out) { @@ -2538,7 +2538,7 @@ bool Fused8BitRowwiseEmbeddingLookupIdx_int32_t_uint8_t_float_true__avx2_fma( const int64_t data_size, const uint8_t* input, const int* indices, - const int64_t* offsets, + const int* offsets, const float* weights, bool normalize_by_lengths, float* out) { diff --git a/caffe2/perfkernels/embedding_lookup_idx.cc b/caffe2/perfkernels/embedding_lookup_idx.cc index 1216c6b77cd..d9b2f0627bc 100644 --- a/caffe2/perfkernels/embedding_lookup_idx.cc +++ b/caffe2/perfkernels/embedding_lookup_idx.cc @@ -23,7 +23,7 @@ static bool EmbeddingLookupGenericSlowIdx( const int64_t data_size, const InType* input, const IndexType* indices, - const int64_t* offsets, + const IndexType* offsets, const float* weights, // optional, can be null for sum reducer const float* scale_bias, // optional scale & bias params for uint8 input bool normalize_by_lengths, @@ -85,7 +85,7 @@ static bool EmbeddingLookupGenericSlowIdx( const int64_t data_size, \ const InType* input, \ const IndexType* indices, \ - const int64_t* offsets, \ + const IndexType* offsets, \ const float* weights, \ const float* scale_bias, \ bool normalize_by_lengths, \ @@ -118,7 +118,7 @@ static bool EmbeddingLookupGenericSlowIdx( const int64_t data_size, \ const InType* input, \ const IndexType* indices, \ - const int64_t* offsets, \ + const IndexType* offsets, \ const float* weights, \ const float* scale_bias, \ bool normalize_by_lengths, \ @@ -163,7 +163,7 @@ static bool EmbeddingLookupGenericSlowIdx( const int64_t data_size, \ const InType* input, \ const IndexType* indices, \ - const int64_t* offsets, \ + const IndexType* offsets, \ const float* weights, \ const float* scale_bias, \ bool normalize_by_lengths, \ diff --git a/caffe2/perfkernels/embedding_lookup_idx.h b/caffe2/perfkernels/embedding_lookup_idx.h index 9092b275935..67573fb21fa 100644 --- a/caffe2/perfkernels/embedding_lookup_idx.h +++ b/caffe2/perfkernels/embedding_lookup_idx.h @@ -48,7 +48,7 @@ void EmbeddingLookupIdx( const std::int64_t data_size, const InType* input, const IndexType* indices, - const int64_t* offsets, + const IndexType* offsets, const float* weights, // optional, can be null for non-weighted sum const float* scale_bias, // optional scale & bias params for uint8 input bool normalize_by_lengths, diff --git a/caffe2/perfkernels/embedding_lookup_idx_avx2.cc b/caffe2/perfkernels/embedding_lookup_idx_avx2.cc index eb61353866c..329598b84d4 100644 --- a/caffe2/perfkernels/embedding_lookup_idx_avx2.cc +++ b/caffe2/perfkernels/embedding_lookup_idx_avx2.cc @@ -17,7 +17,7 @@ static bool EmbeddingLookupIdx_int32_t_float_float__avx2_fma( const int64_t data_size, const float* input, const int* indices, - const int64_t* offsets, + const int* offsets, const float* weights, const float* scale_bias, bool normalize_by_lengths, @@ -402,7 +402,7 @@ bool EmbeddingLookupIdx_int32_t_float_float_false__avx2_fma( const int64_t data_size, const float* input, const int* indices, - const 
int64_t* offsets, + const int* offsets, const float* weights, const float* scale_bias, bool normalize_by_lengths, @@ -427,7 +427,7 @@ bool EmbeddingLookupIdx_int32_t_float_float_true__avx2_fma( const int64_t data_size, const float* input, const int* indices, - const int64_t* offsets, + const int* offsets, const float* weights, const float* scale_bias, bool normalize_by_lengths, @@ -891,7 +891,7 @@ static bool EmbeddingLookupIdx_int32_t_half_float__avx2_fma( const int64_t data_size, const at::Half* input, const int* indices, - const int64_t* offsets, + const int* offsets, const float* weights, const float* scale_bias, bool normalize_by_lengths, @@ -1396,7 +1396,7 @@ bool EmbeddingLookupIdx_int32_t_half_float_false__avx2_fma( const int64_t data_size, const at::Half* input, const int* indices, - const int64_t* offsets, + const int* offsets, const float* weights, const float* scale_bias, bool normalize_by_lengths, @@ -1421,7 +1421,7 @@ bool EmbeddingLookupIdx_int32_t_half_float_true__avx2_fma( const int64_t data_size, const at::Half* input, const int* indices, - const int64_t* offsets, + const int* offsets, const float* weights, const float* scale_bias, bool normalize_by_lengths, @@ -2005,7 +2005,7 @@ static bool EmbeddingLookupIdx_int32_t_uint8_t_float__avx2_fma( const int64_t data_size, const uint8_t* input, const int* indices, - const int64_t* offsets, + const int* offsets, const float* weights, const float* scale_bias, bool normalize_by_lengths, @@ -2523,7 +2523,7 @@ bool EmbeddingLookupIdx_int32_t_uint8_t_float_false__avx2_fma( const int64_t data_size, const uint8_t* input, const int* indices, - const int64_t* offsets, + const int* offsets, const float* weights, const float* scale_bias, bool normalize_by_lengths, @@ -2548,7 +2548,7 @@ bool EmbeddingLookupIdx_int32_t_uint8_t_float_true__avx2_fma( const int64_t data_size, const uint8_t* input, const int* indices, - const int64_t* offsets, + const int* offsets, const float* weights, const float* scale_bias, bool normalize_by_lengths, diff --git a/caffe2/perfkernels/fused_8bit_rowwise_embedding_lookup_idx.cc b/caffe2/perfkernels/fused_8bit_rowwise_embedding_lookup_idx.cc index 28972e4f49a..99a41d742d1 100644 --- a/caffe2/perfkernels/fused_8bit_rowwise_embedding_lookup_idx.cc +++ b/caffe2/perfkernels/fused_8bit_rowwise_embedding_lookup_idx.cc @@ -22,7 +22,7 @@ static bool Fused8BitRowwiseEmbeddingLookupGenericSlowIdx( const int64_t data_size, const InType* input, const IndexType* indices, - const int64_t* offsets, + const IndexType* offsets, const float* weights, // optional, can be null for sum reducer bool normalize_by_lengths, OutType* out) { @@ -88,7 +88,7 @@ static bool Fused8BitRowwiseEmbeddingLookupGenericSlowIdx( const int64_t data_size, \ const uint8_t* input, \ const IndexType* indices, \ - const int64_t* offsets, \ + const IndexType* offsets, \ const float* weights, \ bool normalize_by_lengths, \ OutType* out) { \ @@ -118,7 +118,7 @@ static bool Fused8BitRowwiseEmbeddingLookupGenericSlowIdx( const int64_t data_size, \ const uint8_t* input, \ const IndexType* indices, \ - const int64_t* offsets, \ + const IndexType* offsets, \ const float* weights, \ bool normalize_by_lengths, \ OutType* out) { \ @@ -160,7 +160,7 @@ static bool Fused8BitRowwiseEmbeddingLookupGenericSlowIdx( const int64_t data_size, \ const uint8_t* input, \ const IndexType* indices, \ - const int64_t* offsets, \ + const IndexType* offsets, \ const float* weights, \ bool normalize_by_lengths, \ OutType* out) { \ diff --git 
a/caffe2/perfkernels/fused_8bit_rowwise_embedding_lookup_idx.h b/caffe2/perfkernels/fused_8bit_rowwise_embedding_lookup_idx.h index 9970c8671d0..f7422bd7b75 100644 --- a/caffe2/perfkernels/fused_8bit_rowwise_embedding_lookup_idx.h +++ b/caffe2/perfkernels/fused_8bit_rowwise_embedding_lookup_idx.h @@ -50,7 +50,7 @@ void Fused8BitRowwiseEmbeddingLookupIdx( const std::int64_t data_size, const InType* input, const IndexType* indices, - const int64_t* offsets, + const IndexType* offsets, const float* weights, // optional, can be null for non-weighted sum bool normalize_by_lengths, OutType* out); diff --git a/caffe2/perfkernels/hp_emblookup_codegen.py b/caffe2/perfkernels/hp_emblookup_codegen.py index 75b0c8b583b..402f3bb92a4 100644 --- a/caffe2/perfkernels/hp_emblookup_codegen.py +++ b/caffe2/perfkernels/hp_emblookup_codegen.py @@ -450,7 +450,7 @@ for o in options: args.append(" const " + InType + "* input,") args.append(" const " + IndexType + "* indices,") if opts.use_offsets: - args.append(" const int64_t* offsets,") + args.append(" const " + IndexType + "* offsets,") else: args.append(" const int* lengths,") args.append(" const float* weights,") diff --git a/test/test_nn.py b/test/test_nn.py index 2ce752aa0eb..6b8c97db2f5 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -11610,25 +11610,27 @@ class TestNNDeviceType(NNTestCase): self.assertRaises(RuntimeError, lambda: F.adaptive_max_pool2d(t, [])) self.assertRaises(RuntimeError, lambda: F.adaptive_max_pool3d(t, [])) - def test_embedding_bag_empty_input(self, device): + @dtypes(torch.int, torch.long) + def test_embedding_bag_empty_input(self, device, dtype): m = 4 n = 3 - x = torch.tensor([], device=device, dtype=torch.long) + x = torch.tensor([], device=device, dtype=dtype) for sparse in [True, False]: Embed = torch.nn.EmbeddingBag(m, n, sparse=sparse) Embed.to(device) - output = Embed(input=x, offsets=torch.tensor([0], device=device, dtype=torch.long)) + output = Embed(input=x, offsets=torch.tensor([0], device=device, dtype=dtype)) self.assertEqual(output, torch.zeros_like(output)) - output = Embed(input=x, offsets=torch.tensor([0, 0], device=device, dtype=torch.long)) + output = Embed(input=x, offsets=torch.tensor([0, 0], device=device, dtype=dtype)) self.assertEqual(output, torch.zeros_like(output)) - def test_EmbeddingBag_per_sample_weights_failures(self, device): + @dtypes(torch.int, torch.long) + def test_EmbeddingBag_per_sample_weights_failures(self, device, dtype): # Failure 1: mismatched embeddings / per_sample_weights dtype es = nn.EmbeddingBag(5, 2, mode='sum').to(dtype=torch.float, device=device) - input = torch.tensor([3, 1, 1, 1, 4, 0], dtype=torch.long, device=device) - offsets = torch.tensor([0, 0, 3, 3, 6], dtype=torch.long, device=device) + input = torch.tensor([3, 1, 1, 1, 4, 0], dtype=dtype, device=device) + offsets = torch.tensor([0, 0, 3, 3, 6], dtype=dtype, device=device) per_sample_weights = torch.randn_like(input, dtype=torch.double, device=device) if device == 'cpu': with self.assertRaisesRegex(RuntimeError, 'have the same type as'): @@ -11638,14 +11640,14 @@ class TestNNDeviceType(NNTestCase): es(input, offsets, per_sample_weights) # Failure 2.1: input/per_sample_weights have different sizes (1d input) - input = torch.tensor([3, 1, 1, 1, 4, 0], dtype=torch.long, device=device) - offsets = torch.tensor([0, 0, 3, 3, 6], dtype=torch.long, device=device) + input = torch.tensor([3, 1, 1, 1, 4, 0], dtype=dtype, device=device) + offsets = torch.tensor([0, 0, 3, 3, 6], dtype=dtype, device=device) per_sample_weights = 
torch.randn(5, dtype=torch.float, device=device) with self.assertRaisesRegex(ValueError, 'same shape as the input'): es(input, offsets, per_sample_weights) # Failure 2.2: input/per_sample_weights have different sizes (2d input) - input = torch.randint(5, (7, 3), dtype=torch.long, device=device) + input = torch.randint(5, (7, 3), dtype=dtype, device=device) offsets = None per_sample_weights = torch.randn(7 * 3, dtype=torch.float, device=device) with self.assertRaisesRegex(ValueError, 'same shape as the input'): @@ -11655,7 +11657,7 @@ class TestNNDeviceType(NNTestCase): for unsupported_mode in ('max', 'mean'): es = nn.EmbeddingBag(5, 2, mode=unsupported_mode).to( dtype=torch.float, device=device) - input = torch.randint(5, (7, 3), dtype=torch.long, device=device) + input = torch.randint(5, (7, 3), dtype=dtype, device=device) offsets = None per_sample_weights = torch.randn(7, 3, dtype=torch.float, device=device) with self.assertRaisesRegex(NotImplementedError, @@ -11673,7 +11675,8 @@ class TestNNDeviceType(NNTestCase): assert input.numel() == per_sample_weights.numel() bags = [] - embeddings = weight.index_select(0, input) * per_sample_weights.unsqueeze(1) + long_input = input.to(torch.long) + embeddings = weight.index_select(0, long_input) * per_sample_weights.unsqueeze(1) if include_last_offset: for index in range(len(offsets) - 1): offset = offsets[index] @@ -11698,7 +11701,7 @@ class TestNNDeviceType(NNTestCase): if index + 1 < len(offsets): next_offset = offsets[index + 1] else: - next_offset = len(input) + next_offset = len(long_input) length = next_offset - offset if length == 0: bags.append( @@ -11716,16 +11719,18 @@ class TestNNDeviceType(NNTestCase): bags.append(embeddings.narrow(0, offset, length).max(0)[0]) return torch.stack(bags) - def test_EmbeddingBag_empty_per_sample_weights_and_offsets(self, device): + @dtypesIfCUDA(*itertools.product((torch.int, torch.long), (torch.float, torch.double, torch.half))) + @dtypes(*itertools.product((torch.int, torch.long), (torch.float, torch.double))) + def test_EmbeddingBag_empty_per_sample_weights_and_offsets(self, device, dtypes): # Test empty input and per sample weight, and backward pass. 
There was a CUDA # invalid configuration bug (more context in #46572) - def test_per_sample_weights(mode, dtype, trainable_scale): - es = nn.EmbeddingBag(5, 2, mode=mode).to(dtype=dtype, device=device) + def test_per_sample_weights(mode, trainable_scale): + es = nn.EmbeddingBag(5, 2, mode=mode).to(dtype=dtypes[1], device=device) es.weight.data.copy_( - torch.arange(1, 11, device=device, dtype=dtype).view_as(es.weight)) - input = torch.tensor([], device=device, dtype=torch.long) - offsets = torch.tensor([0, 0, 0, 0, 0], device=device, dtype=torch.long) - per_sample_weights = torch.randn_like(input, dtype=dtype) \ + torch.arange(1, 11, device=device, dtype=dtypes[1]).view_as(es.weight)) + input = torch.tensor([], device=device, dtype=dtypes[0]) + offsets = torch.tensor([0, 0, 0, 0, 0], device=device, dtype=dtypes[0]) + per_sample_weights = torch.randn_like(input, dtype=dtypes[1]) \ .requires_grad_(trainable_scale) ref_per_sample_weights = \ per_sample_weights.detach().requires_grad_(trainable_scale) @@ -11734,7 +11739,7 @@ class TestNNDeviceType(NNTestCase): expected = self._embedding_bag_reference_impl( input, reference_weights, offsets, mode, ref_per_sample_weights) result = es(input, offsets, per_sample_weights) - self.assertEqual(result, expected, atol=dtype2prec_DONTUSE[dtype], rtol=0) + self.assertEqual(result, expected, atol=dtype2prec_DONTUSE[dtypes[1]], rtol=0) grad = torch.randn_like(expected) result.backward(grad) @@ -11742,29 +11747,27 @@ class TestNNDeviceType(NNTestCase): # simply be a zero tensor ref_weights_grad = torch.zeros_like(es.weight) self.assertEqual(es.weight.grad, ref_weights_grad, - atol=dtype2prec_DONTUSE[dtype], rtol=0) + atol=dtype2prec_DONTUSE[dtypes[1]], rtol=0) if trainable_scale: ref_per_sample_weights_grad = torch.empty_like(per_sample_weights) self.assertEqual(per_sample_weights.grad, ref_per_sample_weights_grad, - atol=dtype2prec_DONTUSE[dtype], rtol=0) + atol=dtype2prec_DONTUSE[dtypes[1]], rtol=0) - if device == 'cuda': - dtypes = (torch.float, torch.double, torch.half) - else: - dtypes = (torch.float, torch.double) modes = ('sum',) trainable_scale = (True, False) - for dtype, mode, trainable in itertools.product(dtypes, modes, trainable_scale): - test_per_sample_weights(mode, dtype, trainable) + for mode, trainable in itertools.product(modes, trainable_scale): + test_per_sample_weights(mode, trainable) - def test_EmbeddingBag_per_sample_weights_and_offsets(self, device): - def test_per_sample_weights(mode, dtype, trainable_scale): - es = nn.EmbeddingBag(5, 2, mode=mode).to(dtype=dtype, device=device) + @dtypesIfCUDA(*itertools.product((torch.int, torch.long), (torch.float, torch.double, torch.half))) + @dtypes(*itertools.product((torch.int, torch.long), (torch.float, torch.double))) + def test_EmbeddingBag_per_sample_weights_and_offsets(self, device, dtypes): + def test_per_sample_weights(mode, trainable_scale): + es = nn.EmbeddingBag(5, 2, mode=mode).to(dtype=dtypes[1], device=device) es.weight.data.copy_( - torch.arange(1, 11, device=device, dtype=dtype).view_as(es.weight)) - input = torch.tensor([3, 1, 1, 1, 4, 0], device=device, dtype=torch.long) - offsets = torch.tensor([0, 0, 3, 3, 6], device=device, dtype=torch.long) - per_sample_weights = torch.randn_like(input, dtype=dtype) \ + torch.arange(1, 11, device=device, dtype=dtypes[1]).view_as(es.weight)) + input = torch.tensor([3, 1, 1, 1, 4, 0], device=device, dtype=dtypes[0]) + offsets = torch.tensor([0, 0, 3, 3, 6], device=device, dtype=dtypes[0]) + per_sample_weights = torch.randn_like(input, 
dtype=dtypes[1]) \ .requires_grad_(trainable_scale) ref_per_sample_weights = \ per_sample_weights.detach().requires_grad_(trainable_scale) @@ -11773,39 +11776,37 @@ class TestNNDeviceType(NNTestCase): expected = self._embedding_bag_reference_impl( input, reference_weights, offsets, mode, ref_per_sample_weights) result = es(input, offsets, per_sample_weights) - self.assertEqual(result, expected, atol=dtype2prec_DONTUSE[dtype], rtol=0) + self.assertEqual(result, expected, atol=dtype2prec_DONTUSE[dtypes[1]], rtol=0) grad = torch.randn_like(expected) result.backward(grad) expected.backward(grad) self.assertEqual(es.weight.grad, reference_weights.grad, - atol=dtype2prec_DONTUSE[dtype], rtol=0) + atol=dtype2prec_DONTUSE[dtypes[1]], rtol=0) if trainable_scale: self.assertEqual(per_sample_weights.grad, ref_per_sample_weights.grad, - atol=dtype2prec_DONTUSE[dtype], rtol=0) + atol=dtype2prec_DONTUSE[dtypes[1]], rtol=0) - if device == 'cuda': - dtypes = (torch.float, torch.double, torch.half) - else: - dtypes = (torch.float, torch.double) modes = ('sum',) trainable_scale = (True, False) - for dtype, mode, trainable in itertools.product(dtypes, modes, trainable_scale): - test_per_sample_weights(mode, dtype, trainable) + for mode, trainable in itertools.product(modes, trainable_scale): + test_per_sample_weights(mode, trainable) - def test_EmbeddingBag_per_sample_weights_and_new_offsets(self, device): - def test_per_sample_weights_new_offsets(mode, dtype, trainable_scale, include_last_offset, has_weight=True): - es = nn.EmbeddingBag(5, 2, mode=mode, include_last_offset=include_last_offset).to(dtype=dtype, device=device) + @dtypesIfCUDA(*itertools.product((torch.int, torch.long), (torch.float, torch.double, torch.half))) + @dtypes(*itertools.product((torch.int, torch.long), (torch.float, torch.double))) + def test_EmbeddingBag_per_sample_weights_and_new_offsets(self, device, dtypes): + def test_per_sample_weights_new_offsets(mode, trainable_scale, include_last_offset, has_weight=True): + es = nn.EmbeddingBag(5, 2, mode=mode, include_last_offset=include_last_offset).to(dtype=dtypes[1], device=device) es.weight.data.copy_( - torch.arange(1, 11, device=device, dtype=dtype).view_as(es.weight)) - input = torch.tensor([3, 1, 1, 1, 4, 0], device=device, dtype=torch.long) - offsets = torch.tensor([0, 0, 3, 3, 6], device=device, dtype=torch.long) + torch.arange(1, 11, device=device, dtype=dtypes[1]).view_as(es.weight)) + input = torch.tensor([3, 1, 1, 1, 4, 0], device=device, dtype=dtypes[0]) + offsets = torch.tensor([0, 0, 3, 3, 6], device=device, dtype=dtypes[0]) if include_last_offset: - offsets = torch.cat((offsets, torch.tensor([input.size(0)], device=device, dtype=torch.long)), 0) + offsets = torch.cat((offsets, torch.tensor([input.size(0)], device=device, dtype=dtypes[0])), 0) if has_weight: - per_sample_weights = torch.randn_like(input, device=device, dtype=dtype) \ + per_sample_weights = torch.randn_like(input, device=device, dtype=dtypes[1]) \ .requires_grad_(trainable_scale) ref_per_sample_weights = \ per_sample_weights.detach().requires_grad_(trainable_scale) @@ -11818,51 +11819,48 @@ class TestNNDeviceType(NNTestCase): expected = self._embedding_bag_reference_impl( input, reference_weights, offsets, mode, ref_per_sample_weights, include_last_offset) result = es(input, offsets, per_sample_weights) - self.assertEqual(result, expected, atol=dtype2prec_DONTUSE[dtype], rtol=0) + self.assertEqual(result, expected, atol=dtype2prec_DONTUSE[dtypes[1]], rtol=0) grad = torch.randn_like(expected) 
result.backward(grad) expected.backward(grad) self.assertEqual(es.weight.grad, reference_weights.grad, - atol=dtype2prec_DONTUSE[dtype], rtol=0) + atol=dtype2prec_DONTUSE[dtypes[1]], rtol=0) if has_weight and trainable_scale: self.assertEqual(per_sample_weights.grad, ref_per_sample_weights.grad, - atol=dtype2prec_DONTUSE[dtype], rtol=0) + atol=dtype2prec_DONTUSE[dtypes[1]], rtol=0) - if device == 'cuda': - dtypes = (torch.float, torch.double, torch.half) - else: - dtypes = (torch.float, torch.double) trainable_scale = (True, False) include_last_offset = (True, False) modes = (('sum', False), ('sum', True), ('max', False), ('mean', False)) - for dtype, (mode, has_weight), trainable, include_last_offset in itertools.product( - dtypes, modes, trainable_scale, include_last_offset + for (mode, has_weight), trainable, include_last_offset in itertools.product( + modes, trainable_scale, include_last_offset ): test_per_sample_weights_new_offsets( - mode, dtype, trainable, include_last_offset, has_weight + mode, trainable, include_last_offset, has_weight ) def _test_EmbeddingBag_vs_Embedding(self, N, D, B, L, max_norm=None, mode='mean', device='cpu', - dtype=torch.float, + wdtype=torch.float, + dtype=torch.long, test_per_sample_weights=False, trainable_per_sample_weights=False, sparse=False, test_backward=True, backward_prec=None): - es = nn.EmbeddingBag(N, D, mode=mode, sparse=sparse, max_norm=max_norm).to(device, dtype) - e = nn.Embedding(N, D, max_norm=max_norm).to(device, dtype) + es = nn.EmbeddingBag(N, D, mode=mode, sparse=sparse, max_norm=max_norm).to(device, wdtype) + e = nn.Embedding(N, D, max_norm=max_norm).to(device, wdtype) e.weight.data.copy_(es.weight) - input = torch.randint(N, (B, L), device=device, dtype=torch.long) - offsets = torch.arange(0, B, device=device, dtype=torch.long).mul_(L) - grad_output = torch.rand(B, D, device=device, dtype=dtype) + input = torch.randint(N, (B, L), device=device, dtype=dtype) + offsets = torch.arange(0, B, device=device, dtype=dtype).mul_(L) + grad_output = torch.rand(B, D, device=device, dtype=wdtype) if test_per_sample_weights: # To prevent large gradients, weights should sum to 1 for each bag per_sample_weights = \ - torch.randn(B, L, device=device, dtype=dtype).softmax(dim=-1) + torch.randn(B, L, device=device, dtype=wdtype).softmax(dim=-1) per_sample_weights_reference = \ per_sample_weights.clone().requires_grad_(trainable_per_sample_weights) per_sample_weights.requires_grad_(trainable_per_sample_weights) @@ -11884,7 +11882,7 @@ class TestNNDeviceType(NNTestCase): assert not test_per_sample_weights ref_output = e(input).max(1)[0] - self.assertEqual(output, ref_output, atol=dtype2prec_DONTUSE[dtype], rtol=0) + self.assertEqual(output, ref_output, atol=dtype2prec_DONTUSE[wdtype], rtol=0) if not test_backward: return @@ -11897,7 +11895,7 @@ class TestNNDeviceType(NNTestCase): # We have more floating point error here because we are dealing with larger numbers if backward_prec is None: - needed_prec = dtype2prec_DONTUSE[dtype] * 3 + needed_prec = dtype2prec_DONTUSE[wdtype] * 3 else: needed_prec = backward_prec @@ -11905,13 +11903,15 @@ class TestNNDeviceType(NNTestCase): if test_per_sample_weights and trainable_per_sample_weights: self.assertEqual(per_sample_weights.grad, per_sample_weights_reference.grad, - atol=dtype2prec_DONTUSE[dtype], rtol=0) + atol=dtype2prec_DONTUSE[wdtype], rtol=0) @skipCUDAIf(True, "Temporarily disabled. 
See t54369166") - def test_EmbeddingBag_per_sample_weights_and_no_offsets(self, device): - def run_tests(dtype, mode, sparse, trainable_per_sample_weights): + @dtypesIfCUDA(*itertools.product((torch.int, torch.long), (torch.half, torch.float, torch.double))) + @dtypes(*itertools.product((torch.int, torch.long), (torch.float, torch.double))) + def test_EmbeddingBag_per_sample_weights_and_no_offsets(self, device, dtypes): + def run_tests(mode, sparse, trainable_per_sample_weights): kwargs = dict(test_per_sample_weights=True, device=device, - mode=mode, dtype=dtype, sparse=sparse, + mode=mode, wdtype=dtypes[1], dtype=dtypes[0], sparse=sparse, trainable_per_sample_weights=trainable_per_sample_weights) # Simple case @@ -11926,78 +11926,76 @@ class TestNNDeviceType(NNTestCase): # Large embedding_dim self._test_EmbeddingBag_vs_Embedding(2, 101, 3, 7, **kwargs) - dtypes = (torch.float, torch.double) modes = ('sum',) sparsity = (True, False) trainable_scale = (True, False) - for dtype, mode, sparse, trainable_per_sample_weights in \ - itertools.product(dtypes, modes, sparsity, trainable_scale): - run_tests(dtype, mode, sparse, trainable_per_sample_weights) + for mode, sparse, trainable_per_sample_weights in \ + itertools.product(modes, sparsity, trainable_scale): + run_tests(mode, sparse, trainable_per_sample_weights) # Test CUDA Dense on half precision if device == 'cuda': - dtypes = (torch.half,) modes = ('sum',) sparsity = (False,) trainable_scale = (True, False) - for dtype, mode, sparse, trainable_per_sample_weights in \ - itertools.product(dtypes, modes, sparsity, trainable_scale): - run_tests(dtype, mode, sparse, trainable_per_sample_weights) + for mode, sparse, trainable_per_sample_weights in \ + itertools.product(modes, sparsity, trainable_scale): + run_tests(mode, sparse, trainable_per_sample_weights) - def _test_EmbeddingBag(self, device, mode, sparse, dtype=torch.double, test_backward=True): + def _test_EmbeddingBag(self, device, mode, sparse, wdtype=torch.double, dtype=torch.long, test_backward=True): # check a known test example - es = nn.EmbeddingBag(5, 2, mode=mode, sparse=sparse).to(device, dtype) - es.weight.data.copy_(torch.arange(1, 11, device=device, dtype=dtype).view_as(es.weight)) - input = torch.tensor([3, 1, 1, 1, 4, 0], device=device, dtype=torch.long) - offsets = torch.tensor([0, 0, 3, 3, 6], device=device, dtype=torch.long) + es = nn.EmbeddingBag(5, 2, mode=mode, sparse=sparse).to(device, wdtype) + es.weight.data.copy_(torch.arange(1, 11, device=device, dtype=wdtype).view_as(es.weight)) + input = torch.tensor([3, 1, 1, 1, 4, 0], device=device, dtype=dtype) + offsets = torch.tensor([0, 0, 3, 3, 6], device=device, dtype=dtype) grad_output = torch.tensor( [1, 2, - 3, 4], device=device, dtype=dtype).view(2, 2) + 3, 4], device=device, dtype=wdtype).view(2, 2) grad_output_with_empty = torch.tensor( [99, 99, 1, 2, 99, 99, 3, 4, - 99, 99], device=device, dtype=dtype).view(5, 2) + 99, 99], device=device, dtype=wdtype).view(5, 2) if mode == "sum" or mode == "mean": denominator = 1 if mode == "sum" else 3 expected_output = torch.tensor( [[13, 16], - [13, 16]], device=device, dtype=dtype) / denominator + [13, 16]], device=device, dtype=wdtype) / denominator expected_output_with_empty = torch.tensor( [[0, 0], [13, 16], [0, 0], [13, 16], - [0, 0]], device=device, dtype=dtype) / denominator + [0, 0]], device=device, dtype=wdtype) / denominator expected_grad_weight = torch.tensor( [[3, 4], [5, 8], [0, 0], [1, 2], - [3, 4]], device=device, dtype=dtype) / denominator + [3, 4]], 
device=device, dtype=wdtype) / denominator elif mode == "max": expected_output = torch.tensor( [[7, 8], - [9, 10]], device=device, dtype=dtype) + [9, 10]], device=device, dtype=wdtype) expected_output_with_empty = torch.tensor( [[0, 0], [7, 8], [0, 0], [9, 10], - [0, 0]], device=device, dtype=dtype) + [0, 0]], device=device, dtype=wdtype) expected_grad_weight = torch.tensor( [[0, 0], [0, 0], [0, 0], [1, 2], - [3, 4]], device=device, dtype=dtype) + [3, 4]], device=device, dtype=wdtype) output = es(input, offsets) output.backward(grad_output_with_empty) @@ -12005,7 +12003,7 @@ class TestNNDeviceType(NNTestCase): if sparse: es_weight_grad = es.weight.grad.to_dense() self.assertEqual(output, expected_output_with_empty) - self.assertEqual(es_weight_grad, expected_grad_weight, atol=dtype2prec_DONTUSE[dtype], rtol=0) + self.assertEqual(es_weight_grad, expected_grad_weight, atol=dtype2prec_DONTUSE[wdtype], rtol=0) # check same example except as 2D (2 x 3) input = input.view(2, -1) @@ -12017,12 +12015,12 @@ class TestNNDeviceType(NNTestCase): if sparse: es_weight_grad = es.weight.grad.to_dense() self.assertEqual(output, expected_output) - self.assertEqual(es_weight_grad, expected_grad_weight, atol=dtype2prec_DONTUSE[dtype], rtol=0) + self.assertEqual(es_weight_grad, expected_grad_weight, atol=dtype2prec_DONTUSE[wdtype], rtol=0) # test all empty bags es.zero_grad() - inputs = torch.tensor([], dtype=torch.long, device=device) - offsets = torch.tensor([0, 0, 0, 0], device=device) + inputs = torch.tensor([], dtype=dtype, device=device) + offsets = torch.tensor([0, 0, 0, 0], dtype=dtype, device=device) es(inputs, offsets).sum().backward() dense_grad = es.weight.grad if dense_grad.is_sparse: @@ -12031,7 +12029,7 @@ class TestNNDeviceType(NNTestCase): # now compare EmbeddingBag vs Embedding + Sum/Mean, for constant bag length N, D, B, L = random.randint(1, 100), random.randint(1, 100), random.randint(1, 50), random.randint(1, 50) - kwargs = dict(mode=mode, sparse=sparse, device=device, dtype=dtype, test_backward=test_backward) + kwargs = dict(mode=mode, sparse=sparse, device=device, wdtype=wdtype, dtype=dtype, test_backward=test_backward) self._test_EmbeddingBag_vs_Embedding(N, D, B, L, **kwargs) for max_norm in (None, 3): for p in itertools.product([1, 2], repeat=4): @@ -12039,8 +12037,8 @@ class TestNNDeviceType(NNTestCase): # check that giving illegal input combos raises error es = nn.EmbeddingBag(10, 20, mode=mode, sparse=sparse) - input = torch.ones(3, 4, dtype=torch.long) - offset = torch.arange(0, 3) + input = torch.ones(3, 4, dtype=dtype) + offset = torch.arange(0, 3, dtype=dtype) self.assertRaises(ValueError, lambda: es(input, offset)) self.assertRaises(ValueError, lambda: es(input.view(-1))) offset[0] = 1 @@ -12050,35 +12048,35 @@ class TestNNDeviceType(NNTestCase): offset[-1] = 100 self.assertRaises(RuntimeError, lambda: es(input.view(-1), offset)) - @dtypesIfCUDA(torch.half, torch.float, torch.double) - @dtypes(torch.float, torch.double) - def test_embedding_bag_device(self, device, dtype): - self._test_EmbeddingBag(device, 'sum', False, dtype) - self._test_EmbeddingBag(device, 'mean', False, dtype) - self._test_EmbeddingBag(device, 'max', False, dtype) + @dtypesIfCUDA(*itertools.product((torch.int, torch.long), (torch.float, torch.double, torch.half))) + @dtypes(*itertools.product((torch.int, torch.long), (torch.float, torch.double))) + def test_embedding_bag_device(self, device, dtypes): + self._test_EmbeddingBag(device, 'sum', False, wdtype=dtypes[1], dtype=dtypes[0]) + 
self._test_EmbeddingBag(device, 'mean', False, wdtype=dtypes[1], dtype=dtypes[0]) + self._test_EmbeddingBag(device, 'max', False, wdtype=dtypes[1], dtype=dtypes[0]) test_backward = False if self.device_type == 'cuda': # see 'todo' in test_embedding_bag. - test_backward = dtype is not torch.float16 + test_backward = dtypes[1] is not torch.float16 elif self.device_type == 'cpu': # TODO: figure out why precision on sparse embeddings isn't the # same as for dense. - test_backward = dtype is not torch.float + test_backward = dtypes[1] is not torch.float - self._test_EmbeddingBag(device, 'sum', True, dtype, test_backward=test_backward) - self._test_EmbeddingBag(device, 'mean', True, dtype, test_backward=test_backward) + self._test_EmbeddingBag(device, 'sum', True, wdtype=dtypes[1], dtype=dtypes[0], test_backward=test_backward) + self._test_EmbeddingBag(device, 'mean', True, wdtype=dtypes[1], dtype=dtypes[0], test_backward=test_backward) - @dtypesIfCUDA(torch.half, torch.float, torch.double) - @dtypes(torch.float, torch.double) - def test_embedding_bag_non_contiguous_weight(self, device, dtype): - weight_tensor = torch.randn(3, 4, dtype=dtype, device=device) + @dtypesIfCUDA(*itertools.product((torch.int, torch.long), (torch.float, torch.double, torch.half))) + @dtypes(*itertools.product((torch.int, torch.long), (torch.float, torch.double))) + def test_embedding_bag_non_contiguous_weight(self, device, dtypes): + weight_tensor = torch.randn(3, 4, dtype=dtypes[1], device=device) weight_tensor_non_contig = weight_tensor[:, :3] # This is non-contiguous strided. weight_tensor_contig = weight_tensor_non_contig.clone().contiguous() # Contig-strided. - index = torch.tensor([0, 1, 2], device=device) - offsets = torch.tensor([0, 2], device=device) + index = torch.tensor([0, 1, 2], dtype=dtypes[0], device=device) + offsets = torch.tensor([0, 2], dtype=dtypes[0], device=device) for mode in ['sum', 'mean', 'max']: output_non_contig = F.embedding_bag( input=index, @@ -12097,9 +12095,10 @@ class TestNNDeviceType(NNTestCase): @onlyCUDA @skipCUDAIfNotRocm - def test_embedding_bag_bfloat16(self, device): - self._test_EmbeddingBag(device, 'sum', True, dtype=torch.bfloat16, test_backward=True) - self._test_EmbeddingBag(device, 'mean', True, dtype=torch.bfloat16, test_backward=True) + @dtypes(torch.int, torch.long) + def test_embedding_bag_bfloat16(self, device, dtype): + self._test_EmbeddingBag(device, 'sum', True, wdtype=torch.bfloat16, dtype=dtype, test_backward=True) + self._test_EmbeddingBag(device, 'mean', True, wdtype=torch.bfloat16, dtype=dtype, test_backward=True) @onlyCUDA diff --git a/test/test_torch.py b/test/test_torch.py index f1f22be7553..e3096c0c13e 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -1618,23 +1618,25 @@ class AbstractTestCases: reference[0.0, :, 0.0] = 1 def test_index_add(self): - for dest_contig, src_contig, index_contig in product([True, False], repeat=3): - for other_sizes in ((), (4, 5)): - num_copy, num_dest = 3, 3 - dest = torch.randn(num_dest, *other_sizes) - if not dest_contig: - dest = torch.testing.make_non_contiguous(dest) - src = torch.randn(num_copy, *other_sizes) - if not src_contig: - src = torch.testing.make_non_contiguous(src) - idx = torch.randperm(num_dest).narrow(0, 0, num_copy) - if not index_contig: - idx = torch.testing.make_non_contiguous(idx) - dest2 = dest.clone() - dest.index_add_(0, idx, src) - for i in range(idx.size(0)): - dest2[idx[i]] += src[i] - self.assertEqual(dest, dest2) + for device in torch.testing.get_all_device_types(): + for 
dest_contig, src_contig, index_contig in product([True, False], repeat=3): + for other_sizes in ((), (4, 5)): + for dtype in [torch.int, torch.long]: + num_copy, num_dest = 3, 3 + dest = torch.randn(num_dest, *other_sizes, device=device) + if not dest_contig: + dest = torch.testing.make_non_contiguous(dest) + src = torch.randn(num_copy, *other_sizes, device=device) + if not src_contig: + src = torch.testing.make_non_contiguous(src) + idx = torch.randperm(num_dest, dtype=dtype, device=device).narrow(0, 0, num_copy) + if not index_contig: + idx = torch.testing.make_non_contiguous(idx) + dest2 = dest.clone() + dest.index_add_(0, idx, src) + for i in range(idx.size(0)): + dest2[idx[i]] += src[i] + self.assertEqual(dest, dest2) # add coverage for issue with atomic add that appeared only for # specific dtypes on cuda: @@ -1642,23 +1644,24 @@ class AbstractTestCases: def test_index_add_all_dtypes(self): for device in torch.testing.get_all_device_types(): for dtype in torch.testing.get_all_math_dtypes(device): - size = [5, 5] - if dtype.is_floating_point or dtype.is_complex: - tensor = torch.rand(size, dtype=dtype, device=device) - elif dtype.is_signed: - tensor = torch.randint(-5, 15, size, dtype=dtype, device=device) - else: - tensor = torch.randint(0, 10, size, dtype=dtype, device=device) + for idx_dtype in [torch.int, torch.long]: + size = [5, 5] + if dtype.is_floating_point or dtype.is_complex: + tensor = torch.rand(size, dtype=dtype, device=device) + elif dtype.is_signed: + tensor = torch.randint(-5, 15, size, dtype=dtype, device=device) + else: + tensor = torch.randint(0, 10, size, dtype=dtype, device=device) - # index_add calls atomicAdd on cuda. - zeros = torch.zeros(size, dtype=dtype, device=device) + # index_add calls atomicAdd on cuda. + zeros = torch.zeros(size, dtype=dtype, device=device) - # index_add is not supported for complex dtypes on cuda yet - if device.startswith('cuda') and dtype.is_complex: - continue + # index_add is not supported for complex dtypes on cuda yet + if device.startswith('cuda') and dtype.is_complex: + continue - added = zeros.index_add(0, torch.arange(0, size[0], dtype=torch.long, device=device), tensor) - self.assertEqual(added, tensor) + added = zeros.index_add(0, torch.arange(0, size[0], dtype=idx_dtype, device=device), tensor) + self.assertEqual(added, tensor) def test_t(self): # Test 0D tensors @@ -12735,36 +12738,37 @@ class TestTorchDeviceType(TestCase): self.assertEqual(x, torch.tensor([[0, 2], [0, 5]], dtype=dt, device=device)) def test_index_select(self, device): - src = torch.randn(3, 4, 5, device=device) - # Index can be duplicated. - idx = torch.tensor([2, 1, 0, 1, 2], dtype=torch.long, device=device) - dest = torch.index_select(src, 0, idx) - self.assertEqual(dest.shape, (5, 4, 5)) - for i in range(idx.size(0)): - self.assertEqual(dest[i], src[idx[i]]) + for dtype in [torch.int, torch.long]: + src = torch.randn(3, 4, 5, device=device) + # Index can be duplicated. + idx = torch.tensor([2, 1, 0, 1, 2], dtype=dtype, device=device) + dest = torch.index_select(src, 0, idx) + self.assertEqual(dest.shape, (5, 4, 5)) + for i in range(idx.size(0)): + self.assertEqual(dest[i], src[idx[i]]) - # Check that 'out' is used correctly. - out = torch.randn(5 * 4 * 5, device=device) - dest = torch.index_select(src, 0, idx, out=out.view(5, 4, 5)) - self.assertEqual(dest.shape, (5, 4, 5)) - for i in range(idx.size(0)): - self.assertEqual(dest[i], src[idx[i]]) - out.fill_(0.123) - self.assertEqual(out, dest.view(-1)) # Must point to the same storage. 
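The index-dtype changes exercised by the tests above reduce to one user-visible behavior: index_select and index_add_ now accept int32 ("int") index tensors in addition to the usual int64 ("long") ones. A minimal sketch of that usage, assuming a torch build with this patch applied; the tensor names below are illustrative only::

    import torch

    src = torch.randn(3, 4, 5)
    # int32 indices are now dispatched just like int64 ones.
    idx32 = torch.tensor([2, 1, 0, 1, 2], dtype=torch.int)
    assert torch.equal(torch.index_select(src, 0, idx32),
                       torch.index_select(src, 0, idx32.long()))

    # index_add_ follows the same rule: the index may be int or long.
    dest = torch.zeros(3, 4, 5)
    dest.index_add_(0, torch.tensor([0, 1, 2], dtype=torch.int), src)
    assert torch.equal(dest, src)

On builds without this change, the int32 variants are rejected with a dtype error.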
+ # Check that 'out' is used correctly. + out = torch.randn(5 * 4 * 5, device=device) + dest = torch.index_select(src, 0, idx, out=out.view(5, 4, 5)) + self.assertEqual(dest.shape, (5, 4, 5)) + for i in range(idx.size(0)): + self.assertEqual(dest[i], src[idx[i]]) + out.fill_(0.123) + self.assertEqual(out, dest.view(-1)) # Must point to the same storage. - # Bool tensor - src = torch.tensor([False, True, False, False], device=device, dtype=torch.bool) - idx = torch.tensor([1], dtype=torch.long, device=device) - dest = torch.index_select(src, 0, idx) - self.assertEqual(torch.tensor([True]), dest) + # Bool tensor + src = torch.tensor([False, True, False, False], device=device, dtype=torch.bool) + idx = torch.tensor([1], dtype=dtype, device=device) + dest = torch.index_select(src, 0, idx) + self.assertEqual(torch.tensor([True]), dest) - # Complex Tensor - src = torch.randn(3, 4, 5, dtype=torch.complex64, device=device) - idx = torch.tensor([2, 1, 0, 1, 2], dtype=torch.long, device=device) - dest = torch.index_select(src, 0, idx) - self.assertEqual(dest.shape, (5, 4, 5)) - for i in range(idx.size(0)): - self.assertEqual(dest[i], src[idx[i]]) + # Complex Tensor + src = torch.randn(3, 4, 5, dtype=torch.complex64, device=device) + idx = torch.tensor([2, 1, 0, 1, 2], dtype=dtype, device=device) + dest = torch.index_select(src, 0, idx) + self.assertEqual(dest.shape, (5, 4, 5)) + for i in range(idx.size(0)): + self.assertEqual(dest[i], src[idx[i]]) def test_take_empty(self, device): for input_shape in [(0,), (0, 1, 2, 0), (1, 2, 3)]: diff --git a/torch/_tensor_docs.py b/torch/_tensor_docs.py index 9339d805c1b..3cfc8a17c82 100644 --- a/torch/_tensor_docs.py +++ b/torch/_tensor_docs.py @@ -1678,7 +1678,7 @@ Note: Args: dim (int): dimension along which to index - index (LongTensor): indices of :attr:`tensor` to select from + index (IntTensor or LongTensor): indices of :attr:`tensor` to select from tensor (Tensor): the tensor containing values to add Example:: diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py index e3fc7acfa16..91b70bbbd06 100644 --- a/torch/_torch_docs.py +++ b/torch/_torch_docs.py @@ -3410,7 +3410,7 @@ of :attr:`index`; other dimensions have the same size as in the original tensor. Args: {input} dim (int): the dimension in which we index - index (LongTensor): the 1-D tensor containing the indices to index + index (IntTensor or LongTensor): the 1-D tensor containing the indices to index Keyword args: {out} diff --git a/torch/nn/functional.py b/torch/nn/functional.py index 031d974d497..e78fa40f95d 100644 --- a/torch/nn/functional.py +++ b/torch/nn/functional.py @@ -1940,7 +1940,7 @@ def embedding_bag(input, weight, offsets=None, max_norm=None, norm_type=2, " fixed length sequences. However, found " "offsets of type {}".format(type_str)) offsets = torch.arange(0, input.numel(), input.size(1), - dtype=torch.long, device=input.device) + dtype=input.dtype, device=input.device) input = input.reshape(-1) if per_sample_weights is not None: diff --git a/torch/nn/modules/sparse.py b/torch/nn/modules/sparse.py index 8e8aa556abe..1762297189f 100644 --- a/torch/nn/modules/sparse.py +++ b/torch/nn/modules/sparse.py @@ -34,7 +34,7 @@ class Embedding(Module): initialized from :math:`\mathcal{N}(0, 1)` Shape: - - Input: :math:`(*)`, LongTensor of arbitrary shape containing the indices to extract + - Input: :math:`(*)`, IntTensor or LongTensor of arbitrary shape containing the indices to extract - Output: :math:`(*, H)`, where `*` is the input shape and :math:`H=\text{embedding\_dim}` .. 
note:: @@ -54,7 +54,7 @@ class Embedding(Module): When :attr:`max_norm` is not ``None``, :class:`Embedding`'s forward method will modify the :attr:`weight` tensor in-place. Since tensors needed for gradient computations cannot be modified in-place, performing a differentiable operation on ``Embedding.weight`` before - calling :class:`Embedding`'s forward method requires cloning ``Embedding.weight`` when + calling :class:`Embedding`'s forward method requires cloning ``Embedding.weight`` when :attr:`max_norm` is not ``None``. For example:: n, d, m = 3, 5, 7 @@ -62,7 +62,7 @@ class Embedding(Module): W = torch.randn((m, d), requires_grad=True) idx = torch.tensor([1, 2]) a = embedding.weight.clone() @ W.t() # weight must be cloned for this to be differentiable - b = embedding(idx) @ W.t() # modifies weight in-place + b = embedding(idx) @ W.t() # modifies weight in-place out = (a.unsqueeze(0) + b.unsqueeze(1)) loss = out.sigmoid().prod() loss.backward() @@ -246,9 +246,11 @@ class EmbeddingBag(Module): weight (Tensor): the learnable weights of the module of shape `(num_embeddings, embedding_dim)` initialized from :math:`\mathcal{N}(0, 1)`. - Inputs: :attr:`input` (LongTensor), :attr:`offsets` (LongTensor, optional), and + Inputs: :attr:`input` (IntTensor or LongTensor), :attr:`offsets` (IntTensor or LongTensor, optional), and :attr:`per_index_weights` (Tensor, optional) + - :attr:`input` and :attr:`offsets` have to be of the same type, either int or long + - If :attr:`input` is 2D of shape `(B, N)`, it will be treated as ``B`` bags (sequences) each of fixed length ``N``, and