Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-06 12:20:52 +01:00
Support int32 indices and offsets in nn.EmbeddingBag (#46758)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/46758

Supporting int32 indices and offsets is generally helpful, especially when such tensors are large and need to be transferred to accelerator backends. Since supporting int32 indices together with int64 offsets is of little practical use, the two are required to have the same type.

Test Plan: unit tests

Reviewed By: ngimel

Differential Revision: D24470808

fbshipit-source-id: 94b8a1d0b7fc9fe3d128247aa042c04d7c227f0b
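As a quick illustration of the user-facing change, here is a minimal sketch using the C++ frontend; the sizes and values are illustrative, and it assumes the frontend forwards the index dtypes unchanged to embedding_bag:

#include <torch/torch.h>

int main() {
  // Ten embedding rows of width three, summed per bag.
  torch::nn::EmbeddingBag bag(
      torch::nn::EmbeddingBagOptions(10, 3).mode(torch::kSum));

  // indices and offsets may now both be int32; mixing int32 indices with
  // int64 offsets is rejected, since the two must share one dtype.
  auto indices = torch::tensor({1, 2, 4, 5, 4, 3, 2, 9}, torch::kInt);
  auto offsets = torch::tensor({0, 4}, torch::kInt);  // bags [0:4) and [4:8)

  auto out = bag->forward(indices, offsets);  // shape [2, 3]
}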
This commit is contained in:
parent a2f9c7d4e3
commit 0ec717c830
@@ -557,6 +557,25 @@ inline void deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX() {}
     } \
   }()
 
+#define AT_DISPATCH_INDEX_TYPES(TYPE, NAME, ...) \
+  [&] { \
+    const auto& the_index_type = TYPE; \
+    /* don't use TYPE again in case it is an expensive or side-effect op */ \
+    at::ScalarType _it = ::detail::scalar_type(the_index_type); \
+    switch (_it) { \
+      case at::ScalarType::Int: { \
+        using index_t = int32_t; \
+        return __VA_ARGS__(); \
+      } \
+      case at::ScalarType::Long: { \
+        using index_t = int64_t; \
+        return __VA_ARGS__(); \
+      } \
+      default: \
+        AT_ERROR(#NAME, " not implemented for '", toString(_it), "'"); \
+    } \
+  }()
+
 // ----------------------------------------------------------------------------
 // DEPRECATED MACROS, DON'T USE THESE
 // ----------------------------------------------------------------------------
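The macro added above is the mechanism the kernels below use to handle both index dtypes: the lambda body is instantiated once with index_t = int32_t and once with index_t = int64_t, selected at runtime from the tensor's scalar type. A schematic fragment of the usage pattern (the kernel name and tensor are illustrative, not taken from this diff):

AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "my_index_kernel", [&] {
  // index_t is int32_t or int64_t here, depending on indices.scalar_type().
  auto* indices_data = indices.data_ptr<index_t>();
  for (int64_t i = 0; i < indices.numel(); i++) {
    // ... use indices_data[i] as an index_t ...
  }
});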
@@ -15,7 +15,7 @@ Tensor embedding(const Tensor & weight, const Tensor & indices,
                  int64_t padding_idx, bool scale_grad_by_freq, bool sparse) {
   TORCH_CHECK(weight.dim() >= 1, "'weight' must be at least 1-D");
   auto indices_arg = TensorArg(indices, "indices", 1);
-  checkScalarType("embedding", indices_arg, kLong);
+  checkScalarTypes("embedding", indices_arg, {kLong, kInt});
 
   auto zerofill_padding = [&](Tensor& embedding) {
     if (padding_idx >= 0) {
@@ -57,7 +57,7 @@ Tensor embedding_sparse_backward(
     int64_t padding_idx, bool scale_grad_by_freq) {
 
   auto indices_arg = TensorArg(indices_, "indices", 2);
-  checkScalarType("embedding_backward", indices_arg, kLong);
+  checkScalarTypes("embedding_backward", indices_arg, {kLong, kInt});
 
   // TODO: implement scale_grad_by_freq
   if (scale_grad_by_freq) {
@@ -79,14 +79,14 @@ Tensor embedding_sparse_backward(
 
   // check if all our grad come from padding_idx
   if (grad.numel() == 0) {
-    return at::_sparse_coo_tensor_unsafe(at::empty({1, 0}, indices_.options()),
+    return at::_sparse_coo_tensor_unsafe(at::empty({1, 0}, indices_.options().dtype(kLong)),
                                          at::empty({0, num_features}, dense_options),
                                          weight_size);
   }
 
   auto index = indices.reshape({1, -1});
   auto values = grad.reshape({-1, num_features});
-  return at::_sparse_coo_tensor_unsafe(index, values, weight_size);
+  return at::_sparse_coo_tensor_unsafe(index.to(kLong), values, weight_size);
 }
 
 Tensor embedding_dense_backward_cpu(
@@ -94,15 +94,19 @@ Tensor embedding_dense_backward_cpu(
     int64_t padding_idx, bool scale_grad_by_freq) {
 
   auto indices_arg = TensorArg(indices, "indices", 2);
-  checkScalarType("embedding_backward", indices_arg, kLong);
+  checkScalarTypes("embedding_backward", indices_arg, {kLong, kInt});
 
+  auto grad_weight = at::zeros({num_weights, grad_.size(-1)}, grad_.options());
   auto indices_contig = indices.contiguous();
-  auto indices_data = indices_contig.data_ptr<int64_t>();
   int64_t numel = indices.numel();
+  auto grad = grad_.contiguous().view({numel, grad_.size(-1)});
 
-  std::unique_ptr<int64_t[]> counts;
+  AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_dense_backward_cpu", [&] () {
+    auto indices_data = indices_contig.data_ptr<index_t>();
+
+    std::unique_ptr<index_t[]> counts;
   if (scale_grad_by_freq) {
-    counts.reset(new int64_t[num_weights]);
+    counts.reset(new index_t[num_weights]);
     for (int i = 0; i < numel; i++) {
       counts[indices_data[i]] = 0;
     }
@@ -111,13 +115,10 @@ Tensor embedding_dense_backward_cpu(
       }
     }
 
-  auto grad = grad_.contiguous().view({numel, grad_.size(-1)});
-  auto grad_weight = at::zeros({num_weights, grad_.size(-1)}, grad_.options());
-
-  auto parallel_section = [&](int64_t start, int64_t end) {
+    auto parallel_section = [&](index_t start, index_t end) {
     for (int64_t i = 0; i < numel; i++) {
       if (indices_data[i] != padding_idx) {
-        int64_t k = indices_data[i];
+        index_t k = indices_data[i];
         if (k >= start && k < end) {
           double scale = 1.0;
           if (scale_grad_by_freq) {
@@ -130,14 +131,11 @@ Tensor embedding_dense_backward_cpu(
     };
 
     if (numel > 1000) {
-      // The strategy is to parallelize over sections of the vocabulary, so that
-      // thread 1 handles updates to gradWeight[0..nVocab/nThreads]. Every thread
-      // has to traverse the entire input, but the dominating factor is the axpy
-      // BLAS call.
       at::parallel_for(0, num_weights, 0, parallel_section);
     } else {
       parallel_section(0, num_weights);
    }
+  });
 
   return grad_weight;
 }
@@ -147,14 +145,15 @@ Tensor & embedding_renorm_cpu_(
   auto self_arg = TensorArg(self, "self", 1);
   auto indices_arg = TensorArg(indices, "indices", 2);
   checkDim("embedding_renorm_", self_arg, 2);
-  checkScalarType("embedding_renorm_", indices_arg, kLong);
+  checkScalarTypes("embedding_renorm_", indices_arg, {kLong, kInt});
 
   auto indices_contig = indices.contiguous();
 
   auto num_indices = indices.numel();
-  auto data_ptr = indices_contig.data_ptr<int64_t>();
-  auto sorted_indices = std::vector<int64_t>(data_ptr, data_ptr + num_indices);
-  std::sort(sorted_indices.begin(), sorted_indices.end(), std::less<int64_t>());
+  AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_renorm_cpu_", [&]() {
+    auto data_ptr = indices_contig.data_ptr<index_t>();
+    auto sorted_indices = std::vector<index_t>(data_ptr, data_ptr + num_indices);
+    std::sort(sorted_indices.begin(), sorted_indices.end());
 
   // Note that we cannot use at::parallel_for here because we perform operations on
   // Tensor inside the loop. See github.com/pytorch/pytorch/issues/28370 for more details.
@@ -169,6 +168,7 @@ Tensor & embedding_renorm_cpu_(
         row *= scale;
       }
     }
+  });
 
   return self;
 }
@@ -32,11 +32,11 @@ namespace native {
 template<typename scalar_t>
 scalar_t dot_impl(int64_t n, scalar_t *x, int64_t incx, scalar_t *y, int64_t incy);
 
-static void make_offset2bag(const Tensor &offsets, const Tensor &indices, Tensor& offset2bag) {
+static void make_offset2bag(const Tensor &offsets, Tensor& offset2bag) {
   offset2bag.index_add_(
       0, offsets, at::ones_like(offsets, LEGACY_CONTIGUOUS_MEMORY_FORMAT)); // offset2bag = [1 0 1 0 1]
   offset2bag[0] -= 1; // offset2bag = [0 0 1 0 1]
-  offset2bag = offset2bag.cumsum(0); // offset2bag = [0 0 1 1 2]
+  offset2bag = offset2bag.cumsum(0, offset2bag.scalar_type()); // offset2bag = [0 0 1 1 2]
 }
 
 namespace {
@@ -52,18 +52,19 @@ bool isFastPathIndexSelectScale(const Tensor& src, const Tensor& scale, Tensor&
 // This function combines index_select (using select_indices as the index) and
 // index_add (using add_indices as the index), without creating an intermediary
 // tensor to hold the selected embeddings
-template<typename T>
-void index_select_add(const Tensor &select_indices,
+template<typename data_t, typename index_t>
+typename std::enable_if<!std::is_same<data_t, float>::value, void>::type
+index_select_add(const Tensor &select_indices,
                   const Tensor &add_indices,
                   const Tensor &src,
                   Tensor &output,
                   const Tensor& /*offsets*/,
                   bool /*include_last_offset*/) {
   AT_ASSERT(select_indices.numel() == add_indices.numel());
-  auto* add_indices_data = add_indices.data_ptr<int64_t>();
-  auto* select_indices_data = select_indices.data_ptr<int64_t>();
-  auto* src_data = src.data_ptr<T>();
-  auto* output_data = output.data_ptr<T>();
+  auto* add_indices_data = add_indices.data_ptr<index_t>();
+  auto* select_indices_data = select_indices.data_ptr<index_t>();
+  auto* src_data = src.data_ptr<data_t>();
+  auto* output_data = output.data_ptr<data_t>();
   auto numel = add_indices.numel();
   int64_t ddim = src.size(1);
   auto src_stride0 = src.stride(0);
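The hunk above also changes how the float fast path of index_select_add is chosen: with the extra index_t parameter it can no longer be a full specialization, so the generic and float versions become two overloads constrained with std::enable_if. A standalone sketch of that selection pattern, with illustrative names (not the actual kernels):

#include <cstdint>
#include <type_traits>

// Generic path: participates in overload resolution only when data_t != float.
template <typename data_t, typename index_t>
typename std::enable_if<!std::is_same<data_t, float>::value, void>::type
gather_add(const data_t* src, const index_t* idx, int64_t n) {
  // ... plain scalar loop ...
}

// float path: selected when data_t == float, e.g. to call a vectorized kernel.
template <typename data_t, typename index_t>
typename std::enable_if<std::is_same<data_t, float>::value, void>::type
gather_add(const data_t* src, const index_t* idx, int64_t n) {
  // ... fast path ...
}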
@@ -72,29 +73,30 @@ void index_select_add(const Tensor &select_indices,
   auto output_stride1 = output.stride(1);
 
   for (int64_t i = 0; i < numel; i++) {
-    THBlas_axpy<T>(ddim, 1,
+    THBlas_axpy<data_t>(ddim, 1,
             src_data + src_stride0 * select_indices_data[i], src_stride1,
             output_data + output_stride0 * add_indices_data[i], output_stride1);
   }
 }
 
-template<>
-void index_select_add<float>(const Tensor &select_indices,
+template<typename data_t, typename index_t>
+typename std::enable_if<std::is_same<data_t, float>::value, void>::type
+index_select_add(const Tensor &select_indices,
                   const Tensor &add_indices,
                   const Tensor &src,
                   Tensor &output,
                   const Tensor& offsets,
                   bool include_last_offset) {
   int64_t ddim = src.size(1);
-  auto* select_indices_data = select_indices.data_ptr<int64_t>();
+  auto* select_indices_data = select_indices.data_ptr<index_t>();
   auto* output_data = output.data_ptr<float>();
 
   if (isFastPathIndexSelect(src, output)) {
     auto src_contig = src.contiguous();
     auto* src_data = src_contig.data_ptr<float>();
     int64_t output_size = offsets.numel() - 1;
-    auto* offsets_data = offsets.data_ptr<int64_t>();
-    std::vector<int64_t> offsets_include_last;
+    auto* offsets_data = offsets.data_ptr<index_t>();
+    std::vector<index_t> offsets_include_last;
 
     if (include_last_offset) {
       output_size = offsets.numel() - 1;
@@ -103,15 +105,15 @@ void index_select_add<float>(const Tensor &select_indices,
       offsets_include_last.resize(offsets.numel() + 1);
       std::memcpy(
           offsets_include_last.data(),
-          offsets.data_ptr<int64_t>(),
-          sizeof(int64_t) * offsets.numel());
+          offsets.data_ptr<index_t>(),
+          sizeof(index_t) * offsets.numel());
       offsets_include_last[offsets.numel()] = select_indices.numel();
       offsets_data = offsets_include_last.data();
     }
 
 #ifdef USE_FBGEMM
-    auto kernel_fp32_i64 =
-        fbgemm::GenerateEmbeddingSpMDM<float, int64_t, int64_t>(
+    auto kernel_fp32_index_t =
+        fbgemm::GenerateEmbeddingSpMDM<float, index_t, index_t>(
           /* block_size */ddim,
           /* has_weight */false,
          /* normalize_by_lengths */false,
@@ -121,9 +123,9 @@ void index_select_add<float>(const Tensor &select_indices,
         );
 #endif
     at::parallel_for(
-        0, output_size, 1, [&](int64_t start_idx, int64_t end_idx) {
+        0, output_size, 1, [&](index_t start_idx, index_t end_idx) {
 #ifdef USE_FBGEMM
-          kernel_fp32_i64(
+          kernel_fp32_index_t(
             /* output_size */end_idx - start_idx,
             /* index_size */offsets_data[end_idx] - offsets_data[start_idx],
             /* data_size */src.size(0),
@@ -150,7 +152,7 @@ void index_select_add<float>(const Tensor &select_indices,
   } else {
     AT_ASSERT(select_indices.numel() == add_indices.numel());
     auto* src_data = src.data_ptr<float>();
-    auto* add_indices_data = add_indices.data_ptr<int64_t>();
+    auto* add_indices_data = add_indices.data_ptr<index_t>();
     auto src_stride0 = src.stride(0);
     auto src_stride1 = src.stride(1);
     auto output_stride0 = output.stride(0);
@@ -172,8 +174,9 @@ void index_select_add<float>(const Tensor &select_indices,
 // index_select (using select_indices as the index)
 // mul (scaling by per_sample_weights)
 // index_add (using add_indices as the index)
-template<typename T>
-static void index_select_scale_add(const Tensor &select_indices,
+template<typename data_t, typename index_t>
+static typename std::enable_if<!std::is_same<data_t, float>::value, void>::type
+index_select_scale_add(const Tensor &select_indices,
                   const Tensor &add_indices,
                   const Tensor &scale,
                   const Tensor &src,
@@ -181,10 +184,10 @@ static void index_select_scale_add(const Tensor &select_indices,
                   const Tensor& /*offsets*/,
                   bool /*include_last_offset*/) {
   AT_ASSERT(select_indices.numel() == add_indices.numel());
-  auto* add_indices_data = add_indices.data_ptr<int64_t>();
-  auto* select_indices_data = select_indices.data_ptr<int64_t>();
-  auto* src_data = src.data_ptr<T>();
-  auto* output_data = output.data_ptr<T>();
+  auto* add_indices_data = add_indices.data_ptr<index_t>();
+  auto* select_indices_data = select_indices.data_ptr<index_t>();
+  auto* src_data = src.data_ptr<data_t>();
+  auto* output_data = output.data_ptr<data_t>();
   auto numel = add_indices.numel();
   int64_t ddim = src.size(1);
   auto src_stride0 = src.stride(0);
@@ -192,7 +195,7 @@ static void index_select_scale_add(const Tensor &select_indices,
   auto output_stride0 = output.stride(0);
   auto output_stride1 = output.stride(1);
 
-  auto* scale_data = scale.data_ptr<T>();
+  auto* scale_data = scale.data_ptr<data_t>();
   auto scale_stride = scale.stride(0);
 
   for (int64_t i = 0; i < numel; i++) {
@@ -205,8 +208,9 @@ static void index_select_scale_add(const Tensor &select_indices,
   }
 }
 
-template<>
-void index_select_scale_add<float>(const Tensor &select_indices,
+template<typename data_t, typename index_t>
+typename std::enable_if<std::is_same<data_t, float>::value, void>::type
+index_select_scale_add(const Tensor &select_indices,
                   const Tensor &add_indices,
                   const Tensor &scale,
                   const Tensor &src,
@@ -215,15 +219,15 @@ void index_select_scale_add<float>(const Tensor &select_indices,
                   bool include_last_offset) {
   int64_t ddim = src.size(1);
   auto* scale_data = scale.data_ptr<float>();
-  auto* select_indices_data = select_indices.data_ptr<int64_t>();
+  auto* select_indices_data = select_indices.data_ptr<index_t>();
   auto* output_data = output.data_ptr<float>();
 
   if (isFastPathIndexSelectScale(src, scale, output)) {
     auto src_contig = src.contiguous();
     auto* src_data = src_contig.data_ptr<float>();
     int64_t output_size = offsets.numel() - 1;
-    auto* offsets_data = offsets.data_ptr<int64_t>();
-    std::vector<int64_t> offsets_include_last;
+    auto* offsets_data = offsets.data_ptr<index_t>();
+    std::vector<index_t> offsets_include_last;
 
     if (include_last_offset) {
       output_size = offsets.numel() - 1;
@@ -232,15 +236,15 @@ void index_select_scale_add<float>(const Tensor &select_indices,
       offsets_include_last.resize(offsets.numel() + 1);
       std::memcpy(
          offsets_include_last.data(),
-          offsets.data_ptr<int64_t>(),
-          sizeof(int64_t) * offsets.numel());
+          offsets.data_ptr<index_t>(),
+          sizeof(index_t) * offsets.numel());
       offsets_include_last[offsets.numel()] = select_indices.numel();
       offsets_data = offsets_include_last.data();
     }
 
 #ifdef USE_FBGEMM
-    auto kernel_fp32_i64 =
-        fbgemm::GenerateEmbeddingSpMDM<float, int64_t, int64_t>(
+    auto kernel_fp32_index_t =
+        fbgemm::GenerateEmbeddingSpMDM<float, index_t, index_t>(
          /* block_size */ddim,
          /* has_weight */true,
          /* normalize_by_lengths */false,
@@ -250,9 +254,9 @@ void index_select_scale_add<float>(const Tensor &select_indices,
         );
 #endif
     at::parallel_for(
-        0, output_size, 1, [&](int64_t start_idx, int64_t end_idx) {
+        0, output_size, 1, [&](index_t start_idx, index_t end_idx) {
 #ifdef USE_FBGEMM
-          kernel_fp32_i64(
+          kernel_fp32_index_t(
             /* output_size */end_idx - start_idx,
             /* index_size */offsets_data[end_idx] - offsets_data[start_idx],
             /* data_size */src.size(0),
@@ -279,7 +283,7 @@ void index_select_scale_add<float>(const Tensor &select_indices,
   } else {
     AT_ASSERT(select_indices.numel() == add_indices.numel());
     auto* src_data = src.data_ptr<float>();
-    auto* add_indices_data = add_indices.data_ptr<int64_t>();
+    auto* add_indices_data = add_indices.data_ptr<index_t>();
     auto src_stride0 = src.stride(0);
     auto src_stride1 = src.stride(1);
     auto output_stride0 = output.stride(0);
@@ -308,7 +312,7 @@ static at::Tensor make_bag_size(
     const bool requires_grad) {
   at::Tensor bag_size;
   if (mode == MODE_MEAN || mode == MODE_MAX) {
-    bag_size = at::zeros(offsets.sizes(), indices.options());
+    bag_size = at::zeros(offsets.sizes(), offsets.options());
     // Compute this for MODE_MEAN and MODE_MAX (latter needed for backwards)
     if (offsets.size(0) != 1) {
       bag_size.slice(0, 0, bag_size.size(0) - 1, 1) =
@@ -318,7 +322,7 @@ static at::Tensor make_bag_size(
     bag_size[-1] = indices.size(0) - offsets[-1];
   } else if (requires_grad) {
     // in MODE_SUM, only allocate bag_size if we need gradients
-    bag_size = at::empty(offsets.sizes(), indices.options());
+    bag_size = at::empty(offsets.sizes(), offsets.options());
   }
   return bag_size;
 }
@@ -384,11 +388,11 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> embedding_bag_cpu_max(
   }
   auto max_indices =
       at::zeros({numBags, featureSize}, indices.options());
+  AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_bag_cpu_max", [&] {
+    auto* indices_data = indices.data_ptr<index_t>();
+    auto* offset2bag_data = offset2bag.data_ptr<index_t>();
 
-  auto* indices_data = indices.data_ptr<int64_t>();
-  auto* offset2bag_data = offset2bag.data_ptr<int64_t>();
-
-  auto* max_indices_data = max_indices.data_ptr<int64_t>();
+    auto* max_indices_data = max_indices.data_ptr<index_t>();
   auto max_indices_stride = max_indices.stride(0);
 
   auto* weight_data = weight.data_ptr<scalar_t>();
@@ -397,7 +401,7 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> embedding_bag_cpu_max(
   auto weight_stride1 = weight.stride(1);
   auto output_stride = output.stride(0);
 
-  for (int i = 0; i < numIndices; i++) {
+  for (int i = 0; i < numIndices; ++i) {
     auto bag = offset2bag_data[i];
     auto word_idx = indices_data[i];
 
@@ -413,6 +417,7 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> embedding_bag_cpu_max(
       }
     }
   }
+  });
 
   return std::tuple<Tensor, Tensor, Tensor, Tensor>(
       output, offset2bag, bag_size, max_indices);
@@ -429,19 +434,23 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> _embedding_bag_cpu_impl(
     bool include_last_offset,
     bool requires_grad) {
   auto indices_arg = TensorArg(indices, "indices", 1);
-  checkScalarType("embedding_bag", indices_arg, kLong);
+  checkScalarTypes("embedding_bag", indices_arg, {kLong, kInt});
   auto offsets_arg = TensorArg(offsets, "offsets", 1);
-  checkScalarType("embedding_bag", offsets_arg, kLong);
+  checkScalarTypes("embedding_bag", offsets_arg, {kLong, kInt});
+  checkSameType("embedding_bag", indices_arg, offsets_arg);
   auto weight_arg = TensorArg(weight, "weight", 1);
   checkScalarTypes("embedding_bag", weight_arg, {kFloat, kDouble});
-  int64_t offset_0 = offsets.data_ptr<int64_t>()[0];
-  int64_t offset_n = offsets.data_ptr<int64_t>()[offsets.size(0)-1];
+  AT_DISPATCH_INDEX_TYPES(offsets.scalar_type(), "_embedding_bag_cpu_impl", [&]() {
+    index_t offset_0 = offsets.data_ptr<index_t>()[0];
+    index_t offset_n = offsets.data_ptr<index_t>()[offsets.size(0)-1];
   TORCH_CHECK(offset_0 == 0, "offsets[0] has to be 0, i.e., the first sequence "
               "in the mini-batch has to start from position 0. "
               "However, got ", offsets[0]);
   TORCH_CHECK(offset_n <= indices.size(0), "offsets[-1] can not "
               "be greater than input's length ", indices.size(0), " but got offsets[-1] of ",
               offset_n);
+  });
 
   if (per_sample_weights.defined()) {
     TORCH_CHECK(mode == MODE_SUM,
@@ -494,9 +503,9 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> _embedding_bag_cpu_impl(
     // throw out of bounds error. So to keep it simple we just add one more
     // entry to the end then get rid of it after make_offset2bag.
     offset2bag = at::zeros(
-       {indices.sizes()[0] + 1}, indices.options()); // offset2bag = [0 0 0 0 0]
+       {indices.sizes()[0] + 1}, offsets.options()); // offset2bag = [0 0 0 0 0]
 
-    make_offset2bag(offsets, indices, offset2bag);
+    make_offset2bag(offsets, offset2bag);
 
     offset2bag.resize_({indices.sizes()[0]});
 
@@ -505,15 +514,21 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> _embedding_bag_cpu_impl(
   }
 
   if (mode == MODE_MEAN || mode == MODE_SUM) {
-    AT_DISPATCH_FLOATING_TYPES(weight.scalar_type(), "embedding_bag_cpu", [&]() {
+    // explicitly capture all required variables to work around windows build
+    // TODO: fix this when windows can correctly capture variables in nested lambda
+    AT_DISPATCH_FLOATING_TYPES(weight.scalar_type(), "embedding_bag_cpu",
+      [&indices, &offset2bag, &per_sample_weights, &weight, &output, &offsets, &include_last_offset, &mode]() {
+      AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_bag_cpu",
+        [&indices, &offset2bag, &per_sample_weights, &weight, &output, &offsets, &include_last_offset, &mode]() {
       if (per_sample_weights.defined()) {
         AT_ASSERT(mode == MODE_SUM);
-        index_select_scale_add<scalar_t>(
+        index_select_scale_add<scalar_t, index_t>(
           indices, offset2bag, per_sample_weights, weight, output, offsets, include_last_offset);
       } else {
-        index_select_add<scalar_t>(indices, offset2bag, weight, output, offsets, include_last_offset);
+        index_select_add<scalar_t, index_t>(indices, offset2bag, weight, output, offsets, include_last_offset);
      }
     });
+    });
     auto ret = apply_bag_size(offsets, indices, mode, output, bag_size);
     return std::tuple<Tensor, Tensor, Tensor, Tensor>(ret, offset2bag, bag_size, bag_size);
   } else { // MODE_MAX
@@ -598,23 +613,24 @@ Tensor _embedding_bag_backward(const Tensor &grad, const Tensor &indices,
                               bool sparse,
                               const Tensor& per_sample_weights) {
   auto indices_arg = TensorArg(indices, "indices", 1);
-  checkScalarType("embedding_bag", indices_arg, kLong);
+  checkScalarTypes("embedding_bag", indices_arg, {kLong, kInt});
   checkContiguous("embedding_bag", indices_arg);
   auto offsets_arg = TensorArg(offsets, "offsets", 1);
-  checkScalarType("embedding_bag", offsets_arg, kLong);
+  checkScalarTypes("embedding_bag", offsets_arg, {kLong, kInt});
+  checkSameType("embedding_bag", indices_arg, offsets_arg);
   checkContiguous("embedding_bag", offsets_arg);
 
   Tensor offset2bag_;
   if (indices.numel() != 0 && offset2bag.numel() == 0) {
     offset2bag_ = at::zeros(
-       {indices.sizes()[0] + 1}, indices.options()); // offset2bag = [0 0 0 0 0]
+       {indices.sizes()[0] + 1}, offsets.options()); // offset2bag = [0 0 0 0 0]
 
-    make_offset2bag(offsets, indices, offset2bag_);
+    make_offset2bag(offsets, offset2bag_);
 
     offset2bag_.resize_({indices.sizes()[0]});
   } else {
     auto offset2bag_arg = TensorArg(offset2bag, "offset2bag", 1);
-    checkScalarType("embedding_bag", offset2bag_arg, kLong);
+    checkScalarTypes("embedding_bag", offset2bag_arg, {kLong, kInt});
     checkContiguous("embedding_bag", offset2bag_arg);
     offset2bag_ = offset2bag;
   }
@@ -648,11 +664,12 @@ static Tensor _embedding_bag_dense_backward_cpu_max(
   return index_grad_weight;
 }
 
-static std::vector<int64_t> compute_counts(
+template<typename index_t>
+static std::vector<index_t> compute_counts(
     int64_t num_weights,
-    int64_t* indices_data,
+    index_t* indices_data,
     int64_t indices_length) {
-  std::vector<int64_t> counts(num_weights, 0);
+  std::vector<index_t> counts(num_weights, 0);
   for (int i = 0; i < indices_length; i++) {
     counts[indices_data[i]]++;
   }
@@ -668,12 +685,13 @@ static std::vector<int64_t> compute_counts(
 // counts_uniq: [3, 4, 6, 7]
 //
 // The unique indices can be found at index 0, 3, 4, 6.
-static std::vector<int64_t> compute_counts_uniq(
+template<typename index_t>
+static std::vector<index_t> compute_counts_uniq(
     int64_t num_weights,
-    int64_t* indices_data,
+    index_t* indices_data,
     int64_t indices_length,
-    const std::vector<int64_t>& counts) {
-  std::vector<int64_t> counts_uniq;
+    const std::vector<index_t>& counts) {
+  std::vector<index_t> counts_uniq;
   counts_uniq.reserve(num_weights);
   int64_t o = 0;
   for (int64_t i = 0; i < indices_length; i += counts[indices_data[i]]) {
@@ -714,21 +732,31 @@ void _embedding_bag_dense_backward_cpu_sum_mean(
     per_sample_weights_stride = per_sample_weights->stride(0);
   }
 
-  auto* indices_data = indices.data_ptr<int64_t>();
-  auto* offsets_data = offsets_.data_ptr<int64_t>();
-  auto* offset2bag_data = offset2bag.data_ptr<int64_t>();
   int64_t numel = indices.numel();
 
+  // explicitly capture all required variables to work around windows build
+  // TODO: fix this when windows can correctly capture variables in nested lambda
+  AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "_embedding_bag_dense_backward_cpu_sum_mean",
+    [&indices, &offsets_, &offset2bag, &num_weights, &numel, &per_sample_weights,
+      &per_sample_weights_data, &per_sample_weights_stride, &mode, &scale_grad_by_freq,
+      &grad, &index_grad_weight] {
+    auto* indices_data = indices.data_ptr<index_t>();
+    auto* offsets_data = offsets_.data_ptr<index_t>();
+    auto* offset2bag_data = offset2bag.data_ptr<index_t>();
+
   auto counts = compute_counts(num_weights, indices_data, numel);
   auto next_unique_index_idx =
       compute_counts_uniq(num_weights, indices_data, numel, counts);
 
-  auto loop = [&](int64_t start, int64_t end) {
-    for (int64_t i = start; i < end; i++) {
-      int64_t start = i == 0 ? 0 : next_unique_index_idx[i - 1];
-      int64_t index = indices_data[start];
-      for (int64_t j = start; j < next_unique_index_idx[i]; j++) {
-        int64_t source = offset2bag_data[j];
+  auto loop =
+    [&next_unique_index_idx, &indices_data, &offset2bag_data, &per_sample_weights,
+      &mode, &per_sample_weights_data, &per_sample_weights_stride, &scale_grad_by_freq,
+      &counts, &offsets_, &indices, &offsets_data, &grad, &index_grad_weight](index_t start, index_t end) {
+    for (index_t i = start; i < end; i++) {
+      index_t start = i == 0 ? 0 : next_unique_index_idx[i - 1];
+      index_t index = indices_data[start];
+      for (index_t j = start; j < next_unique_index_idx[i]; j++) {
+        index_t source = offset2bag_data[j];
         double scale = 1.0;
         if (per_sample_weights) {
           AT_ASSERT(mode == MODE_SUM);
@@ -757,11 +785,13 @@ void _embedding_bag_dense_backward_cpu_sum_mean(
       }
     }
   };
 
   if (numel > 1000) {
     at::parallel_for(0, (int64_t)next_unique_index_idx.size(), 0, loop);
   } else {
     loop(0, (int64_t)next_unique_index_idx.size());
   }
+  });
 }
 
 Tensor _embedding_bag_dense_backward_cpu(const Tensor &grad_, const Tensor &indices_,
@@ -820,20 +850,20 @@ Tensor _embedding_bag_per_sample_weights_backward_cpu_template(
   auto output = at::zeros({num_samples}, grad.options());
 
   auto indices_arg = TensorArg(indices, "indices", 1);
-  checkScalarType("embedding_bag", indices_arg, kLong);
+  checkScalarTypes("embedding_bag", indices_arg, {kLong, kInt});
   checkContiguous("embedding_bag", indices_arg);
 
   Tensor offset2bag_;
   if (indices.numel() != 0 && offset2bag.numel() == 0) {
     offset2bag_ = at::zeros(
-       {indices.sizes()[0] + 1}, indices.options()); // offset2bag = [0 0 0 0 0]
+       {indices.sizes()[0] + 1}, offset2bag.options()); // offset2bag = [0 0 0 0 0]
 
-    make_offset2bag(offsets, indices, offset2bag_);
+    make_offset2bag(offsets, offset2bag_);
 
     offset2bag_.resize_({indices.sizes()[0]});
   } else {
     auto offset2bag_arg = TensorArg(offset2bag, "offset2bag", 1);
-    checkScalarType("embedding_bag", offset2bag_arg, kLong);
+    checkScalarTypes("embedding_bag", offset2bag_arg, {kLong, kInt});
     checkContiguous("embedding_bag", offset2bag_arg);
     offset2bag_ = offset2bag;
   }
@@ -846,15 +876,22 @@ Tensor _embedding_bag_per_sample_weights_backward_cpu_template(
   auto weight_stride0 = weight.stride(0);
   auto weight_stride1 = weight.stride(1);
 
-  auto* indices_data = indices.data_ptr<int64_t>();
+  // explicitly capture all required variables to work around windows build
+  // TODO: fix this when windows can correctly capture variables in nested lambda
+  AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "_embedding_bag_per_sample_weights_backward_cpu_template",
+    [&indices, &output, &offset2bag_, &num_samples, &embedding_features,
+      &grad_data, &grad_stride0, &grad_stride1, &weight_data, &weight_stride0, &weight_stride1] () {
+    auto* indices_data = indices.data_ptr<index_t>();
 
   // The following are contiguous
   auto* output_data = output.data_ptr<scalar_t>();
-  auto* offset2bag_data = offset2bag_.data_ptr<int64_t>();
+  auto* offset2bag_data = offset2bag_.data_ptr<index_t>();
 
   // XXX: 64 was arbitrarily chosen. There is probably a sweet spot for this number.
-  parallel_for(0, num_samples, 64, [&](int64_t begin, int64_t end) {
-    for (int64_t sample_idx = begin; sample_idx < end; sample_idx++) {
+  parallel_for(0, num_samples, 64,
+    [&embedding_features, &grad_data, &grad_stride0, &grad_stride1, &weight_data, &weight_stride0,
+      &weight_stride1, &offset2bag_data, &indices_data, &output_data](index_t begin, index_t end) {
+    for (index_t sample_idx = begin; sample_idx < end; sample_idx++) {
       auto bag_idx = offset2bag_data[sample_idx];
       auto embedding_idx = indices_data[sample_idx];
 
@@ -864,6 +901,7 @@ Tensor _embedding_bag_per_sample_weights_backward_cpu_template(
           weight_data + weight_stride0 * embedding_idx, weight_stride1);
     }
   });
+  });
   return output;
 }
 
@@ -381,7 +381,8 @@ Tensor& index_add_cpu_(Tensor & self, int64_t dim, const Tensor & index, const T
 
   auto numel = index.numel();
   TORCH_CHECK_INDEX(index.dim() <= 1, "index_add_(): Index is supposed to be a vector");
-  TORCH_CHECK(index.scalar_type() == ScalarType::Long, "index_add_(): Expected dtype int64 for index");
+  TORCH_CHECK(index.scalar_type() == ScalarType::Long || index.scalar_type() == ScalarType::Int,
+              "index_add_(): Expected dtype int32/int64 for index");
   TORCH_CHECK(self.scalar_type() == source.scalar_type(),
               "index_add_(): self and source must have the same scalar type");
   TORCH_CHECK(dim == 0 || dim < source.dim(),
@@ -394,7 +395,6 @@ Tensor& index_add_cpu_(Tensor & self, int64_t dim, const Tensor & index, const T
   at::assert_no_partial_overlap(self, source);
 
   auto index_contig = index.contiguous();
-  auto index_data = index_contig.data_ptr<int64_t>();
 
   if (self.dim() > 1) {
     // Equivalent to:
@@ -414,6 +414,8 @@ Tensor& index_add_cpu_(Tensor & self, int64_t dim, const Tensor & index, const T
     auto self_dim_size = self.size(dim);
     auto iter = TensorIterator::binary_op(selfSlice, selfSlice, sourceSlice);
 
+    AT_DISPATCH_INDEX_TYPES(index.scalar_type(), "index_add_cpu_", [&] () {
+      auto index_data = index_contig.data_ptr<index_t>();
     for (auto i = 0; i < numel; i++) {
       auto self_i = index_data[i];
       TORCH_CHECK_INDEX((self_i >= 0) && (self_i < self_dim_size), "index out of range in self");
@@ -424,16 +426,22 @@ Tensor& index_add_cpu_(Tensor & self, int64_t dim, const Tensor & index, const T
       iter.unsafe_replace_operand(2, source_data);
       add_stub(iter.device_type(), iter, 1);
     }
+    });
   }
   else {
     TORCH_CHECK(source.dim() <= 1, "source.dim() (", source.dim(), ") must one or zero for given self.dim() (", self.dim(), ")");
 
-    AT_DISPATCH_ALL_TYPES(self.scalar_type(), "index_add_", [&] {
+    // explicitly capture all required variables to work around windows build
+    // TODO: fix this when windows can correctly capture variables in nested lambda
+    AT_DISPATCH_ALL_TYPES(self.scalar_type(), "index_add_", [&self, &source, &dim, &index_contig, &numel] {
       auto self_stride = self.dim() == 0 ? 1 : self.stride(dim);
       auto source_stride = source.dim() == 0 ? 1 : source.stride(dim);
       // TODO: Maybe TensorAccessor can beused here?
       auto* self_ptr = self.data_ptr<scalar_t>();
       auto* source_ptr = source.data_ptr<scalar_t>();
+      AT_DISPATCH_INDEX_TYPES(index_contig.scalar_type(), "index_add_cpu_",
+        [&index_contig, &numel, &self, &self_ptr, &self_stride, &source_ptr, &source_stride] {
+        auto index_data = index_contig.data_ptr<index_t>();
       for (auto i = 0; i < numel; i++) {
         auto self_i = index_data[i];
         TORCH_CHECK_INDEX((self_i >= 0) && (self_i < self.numel()), "index out of range in self");
@@ -441,6 +449,7 @@ Tensor& index_add_cpu_(Tensor & self, int64_t dim, const Tensor & index, const T
         *self_ip += *(source_ptr + i * source_stride);
       }
     });
+    });
   }
   return self;
 }
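The index_add_ relaxation above is also visible from the public API; a minimal CPU sketch with illustrative values:

#include <torch/torch.h>

int main() {
  auto x   = torch::zeros({5, 3});
  auto src = torch::ones({2, 3});
  auto idx = torch::tensor({0, 4}, torch::kInt);  // int32 index is now accepted
  x.index_add_(0, idx, src);                      // rows 0 and 4 of x become ones
}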
@@ -454,7 +463,7 @@ Tensor & index_select_out_cpu_(Tensor & result, const Tensor & self, int64_t dim
 
   auto numel = index.numel();
   TORCH_CHECK_INDEX(index.dim() <= 1, "index_select(): Index is supposed to be a vector");
-  TORCH_CHECK(index.scalar_type() == ScalarType::Long, "index_select(): Expected dtype int64 for index");
+  TORCH_CHECK(index.scalar_type() == ScalarType::Long || index.scalar_type() == ScalarType::Int, "index_select(): Expected dtype int32 or int64 for index");
   TORCH_CHECK(self.scalar_type() == result.scalar_type(),
               "index_select(): self and result must have the same scalar type");
   TORCH_CHECK(dim == 0 || dim < self.dim(),
@@ -468,7 +477,6 @@ Tensor & index_select_out_cpu_(Tensor & result, const Tensor & self, int64_t dim
   result.resize_(result_size);
 
   auto index_contig = index.contiguous();
-  auto index_data = index_contig.data_ptr<int64_t>();
 
   if (self.dim() > 1) {
     if (numel == 0 || self.numel() == 0) {
@@ -492,8 +500,16 @@ Tensor & index_select_out_cpu_(Tensor & result, const Tensor & self, int64_t dim
       .build();
 
     auto grain_size = at::internal::GRAIN_SIZE;
-    auto outer_loop = [&](int64_t start, int64_t end) {
+    auto outer_loop =
+      // explicitly capture all required variables to work around windows build
+      // TODO: fix this when windows can correctly capture variables in nested lambda
+      [&index_contig, &iter, &self_dim_size, &selfSlice_data, &self_stride_bytes, &resultSlice_data,
+        &result_stride_bytes](int64_t start, int64_t end) {
       auto sub_iter = TensorIterator(iter);
+      AT_DISPATCH_INDEX_TYPES(index_contig.scalar_type(), "index_select_out_cpu_",
+        [&index_contig, &start, &end, &sub_iter, &self_dim_size, &selfSlice_data, &self_stride_bytes,
+          &resultSlice_data, &result_stride_bytes] () {
+        auto index_data = index_contig.data_ptr<index_t>();
       for (int64_t i = start; i < end; i++) {
         auto self_i = index_data[i];
         TORCH_CHECK_INDEX((self_i >= 0) && (self_i < self_dim_size), "index out of range in self");
@@ -502,7 +518,8 @@ Tensor & index_select_out_cpu_(Tensor & result, const Tensor & self, int64_t dim
         sub_iter.unsafe_replace_operand(0, result_data);
         sub_iter.unsafe_replace_operand(1, self_data);
         copy_stub(sub_iter.device_type(), sub_iter, false);
-      }
+      };
+      });
     };
 
     // parallel on inner loop in case the slice is large enough;
@@ -513,7 +530,15 @@ Tensor & index_select_out_cpu_(Tensor & result, const Tensor & self, int64_t dim
     // use a fast loop when self and result are contiguous and of the same data type
     if (iter.is_contiguous() && self.scalar_type() == result.scalar_type()) {
       auto slice_size_bytes = slice_size * elementSize(self.scalar_type());
-      at::parallel_for(0, numel, grain_size / slice_size, [&](int64_t start, int64_t end) {
+      // explicitly capture all required variables to work around windows build
+      // TODO: fix this when windows can correctly capture variables in nested lambda
+      at::parallel_for(0, numel, grain_size / slice_size,
+        [&index_contig, &slice_size_bytes, &self_dim_size, &selfSlice_data,
+          &self_stride_bytes, &resultSlice_data, &result_stride_bytes](int64_t start, int64_t end) {
+        AT_DISPATCH_INDEX_TYPES(index_contig.scalar_type(), "index_select_out_cpu_",
+          [&index_contig, &slice_size_bytes, &self_dim_size, &selfSlice_data,
+            &self_stride_bytes, &resultSlice_data, &result_stride_bytes, &start, &end] () {
+          auto index_data = index_contig.data_ptr<index_t>();
         for (int64_t i = start; i < end; i++) {
           auto self_i = index_data[i];
           TORCH_CHECK_INDEX((self_i >= 0) && (self_i < self_dim_size), "index out of range in self");
@@ -522,20 +547,26 @@ Tensor & index_select_out_cpu_(Tensor & result, const Tensor & self, int64_t dim
           memcpy(result_data, self_data, slice_size_bytes);
         }
       });
+      });
     } else {
       at::parallel_for(0, numel, grain_size / slice_size, outer_loop);
     }
   } else {
     TORCH_CHECK(result.dim() <= 1, "result.dim() (", result.dim(), ") must one or zero for given self.dim() (", self.dim(), ")");
-    AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Bool, self.scalar_type(), "index_select", [&] {
+    // explicitly capture all required variables to work around windows build
+    // TODO: fix this when windows can correctly capture variables in nested lambda
+    AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Bool, self.scalar_type(), "index_select",
+      [&index_contig, &self, &result, &dim, &numel] {
       auto self_stride = self.dim() == 0 ? 1 : self.stride(dim);
      auto result_stride = result.dim() == 0 ? 1 : result.stride(dim);
 
      auto self_data_ptr = self.data_ptr<scalar_t>();
      auto result_data_ptr = result.data_ptr<scalar_t>();
      auto self_numel = self.numel();
+      AT_DISPATCH_INDEX_TYPES(index_contig.scalar_type(), "index_select_out_cpu_",
+        [&index_contig, &numel, &self_numel, &self_data_ptr, &self_stride, &result_data_ptr, &result_stride] {
+        auto index_data = index_contig.data_ptr<index_t>();
       for (auto i = 0; i < numel; i++) {
         auto self_i = index_data[i];
         TORCH_CHECK_INDEX((self_i >= 0) && (self_i < self_numel), "index out of range in self");
@@ -543,6 +574,7 @@ Tensor & index_select_out_cpu_(Tensor & result, const Tensor & self, int64_t dim
         *(result_data_ptr + i * result_stride) = *self_ip;
       }
     });
+    });
   }
 
   return result;
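Likewise for index_select; a minimal CPU sketch with an int32 index tensor (illustrative values):

#include <torch/torch.h>

int main() {
  auto t    = torch::arange(12, torch::kFloat).reshape({4, 3});
  auto idx  = torch::tensor({3, 0}, torch::kInt);      // int32 index is now accepted
  auto rows = torch::index_select(t, /*dim=*/0, idx);  // picks rows 3 and 0
}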
@@ -29,9 +29,10 @@ static const int BLOCKDIMY = 32;

template
<typename scalar_t,
typename accscalar_t>
typename accscalar_t,
typename index_t>
__global__ void embedding_backward_feature_kernel
(int64_t* indices,
(index_t* indices,
const scalar_t* __restrict__ grad,
scalar_t* __restrict__ grad_weight,
int n, // OK to pass as int, we don't expect 2 billion+ samples in one shot

@@ -117,10 +118,10 @@ __global__ void embedding_backward_feature_kernel
}

template <typename scalar_t>
template <typename scalar_t, typename index_t>
__global__ void embedding_backward_kernel(
int64_t* input, int64_t* indices, scalar_t* grad_output, scalar_t* grad_weight,
index_t* input, index_t* indices, scalar_t* grad_output, scalar_t* grad_weight,
int64_t* count, int64_t numel, int64_t stride, int padding_idx) {
index_t* count, int64_t numel, int64_t stride, int padding_idx) {

using accscalar_t = acc_type<scalar_t, true>;
int idx = blockIdx.x * 4 + threadIdx.y;

@@ -179,9 +180,9 @@ __global__ void embedding_backward_kernel(
}

/* Calculate norms of the rows of weight_ptr given by idx_ptr and capture them in norms */
template <typename scalar_t, typename accscalar_t>
template <typename scalar_t, typename accscalar_t, typename index_t>
__global__ void renorm_kernel(
scalar_t* weights, int64_t* indices, accscalar_t max_norm,
scalar_t* weights, index_t* indices, accscalar_t max_norm,
accscalar_t norm_type, int64_t dim,
int64_t weights_stride0, int64_t weights_stride1) {

@@ -228,7 +229,7 @@ Tensor embedding_dense_backward_cuda(const Tensor & grad_, const Tensor & indice
bool scale_grad_by_freq) {
auto grad_arg = TensorArg(grad_, "grad", 1);
auto indices_arg = TensorArg(indices, "indices", 1);
checkScalarType("embedding_backward", indices_arg, kLong);
checkScalarTypes("embedding_backward", indices_arg, {kLong, kInt});
checkSameGPU("embedding_backward", grad_arg, indices_arg);

auto num_indices = indices.numel();

@@ -250,12 +251,13 @@ Tensor embedding_dense_backward_cuda(const Tensor & grad_, const Tensor & indice
{
AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "embedding_backward", [&] {
using accscalar_t = acc_type<scalar_t, true>;
embedding_backward_feature_kernel<scalar_t, accscalar_t>
AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_dense_backward_cuda", [&] () {
embedding_backward_feature_kernel<scalar_t, accscalar_t, index_t>
<<<grid,
block,
sizeof(accscalar_t)*C10_WARP_SIZE*BLOCKDIMY + sizeof(int)*C10_WARP_SIZE*BLOCKDIMY,
stream>>>
(indices_contig.data_ptr<int64_t>(),
(indices_contig.data_ptr<index_t>(),
grad.data_ptr<scalar_t>(),
grad_weight.data_ptr<scalar_t>(),
static_cast<int>(num_indices),

@@ -264,12 +266,15 @@ Tensor embedding_dense_backward_cuda(const Tensor & grad_, const Tensor & indice
TORCH_CUDA_KERNEL_LAUNCH_CHECK();
});
});
});
return grad_weight;
}

auto sorted_indices = at::empty_like(indices, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
auto orig_indices = at::empty_like(indices, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
using device_ptr = thrust::device_ptr<int64_t>;
Tensor count;
AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_dense_backward_cuda", [&] () {
using device_ptr = thrust::device_ptr<index_t>;

// Sort the inputs into sorted with the corresponding indices; we
// don't need a stable or multidimensional sort, so just use Thrust

@@ -281,17 +286,16 @@ Tensor embedding_dense_backward_cuda(const Tensor & grad_, const Tensor & indice
auto policy = thrust::cuda::par(allocator).on(stream);

// Fill sortedOrigIndices with sequential indices
auto count_iter = thrust::counting_iterator<int64_t>(0);
auto count_iter = thrust::counting_iterator<index_t>(0);
auto orig_data = device_ptr(orig_indices.data_ptr<int64_t>());
auto orig_data = device_ptr(orig_indices.data_ptr<index_t>());
thrust::copy(policy, count_iter, count_iter + num_indices, orig_data);

// Sort; a stable sort is not required
auto sorted_data = device_ptr(sorted_indices.data_ptr<int64_t>());
auto sorted_data = device_ptr(sorted_indices.data_ptr<index_t>());
thrust::sort_by_key(policy, sorted_data, sorted_data + num_indices, orig_data,
ThrustLTOp<int64_t>());
ThrustLTOp<index_t>());
}

Tensor count;
if (scale_grad_by_freq) {
count = at::empty_like(indices, LEGACY_CONTIGUOUS_MEMORY_FORMAT);

@@ -301,8 +305,8 @@ Tensor embedding_dense_backward_cuda(const Tensor & grad_, const Tensor & indice
// Compute an increasing sequence per unique item in sortedIndices:
// sorted: 2 5 5 5 7 7 8 9 9
// count: 1 1 2 3 1 2 1 1 2
auto sorted_data = device_ptr(sorted_indices.data_ptr<int64_t>());
auto sorted_data = device_ptr(sorted_indices.data_ptr<index_t>());
auto count_data = device_ptr(count.data_ptr<int64_t>());
auto count_data = device_ptr(count.data_ptr<index_t>());
thrust::inclusive_scan_by_key(
policy,
sorted_data,

@@ -320,10 +324,11 @@ Tensor embedding_dense_backward_cuda(const Tensor & grad_, const Tensor & indice
thrust::make_reverse_iterator(sorted_data),
thrust::make_reverse_iterator(count_data + num_indices),
thrust::make_reverse_iterator(count_data + num_indices),
thrust::equal_to<int64_t>(),
thrust::equal_to<index_t>(),
thrust::maximum<int64_t>()
thrust::maximum<index_t>()
);
}
});

return embedding_backward_cuda_kernel(grad, orig_indices,
sorted_indices, count, num_weights, padding_idx);

@@ -340,14 +345,15 @@ Tensor & embedding_renorm_cuda_(Tensor & self, const Tensor & indices,
auto allocator = THCThrustAllocator(globalContext().lazyInitCUDA());
auto policy = thrust::cuda::par(allocator).on(stream);

using device_ptr = thrust::device_ptr<int64_t>;
AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_renorm_cuda_", [&] () {
using device_ptr = thrust::device_ptr<index_t>;

auto num_indices = indices.numel();
auto indices_contig = std::get<0>(indices.sort()).contiguous();
auto indices_data = device_ptr(indices_contig.data_ptr<int64_t>());
auto indices_data = device_ptr(indices_contig.data_ptr<index_t>());

auto unique_indices = at::empty(indices.numel(), indices.options());
auto unique_data = device_ptr(unique_indices.data_ptr<int64_t>());
auto unique_data = device_ptr(unique_indices.data_ptr<index_t>());
auto end = thrust::unique_copy(policy, indices_data, indices_data + num_indices, unique_data);
auto num_unique_indices = static_cast<int>(end - unique_data);

@@ -360,13 +366,14 @@ Tensor & embedding_renorm_cuda_(Tensor & self, const Tensor & indices,
using accscalar_t = acc_type<scalar_t, true>;
renorm_kernel<<<grid, block, 128 * sizeof(accscalar_t), stream>>>(
self.data_ptr<scalar_t>(),
unique_indices.data_ptr<int64_t>(),
unique_indices.data_ptr<index_t>(),
static_cast<accscalar_t>(max_norm),
static_cast<accscalar_t>(norm_type),
dim, self.stride(0), self.stride(1));
TORCH_CUDA_KERNEL_LAUNCH_CHECK();
});
});
});
return self;
}

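The backward path above derives per-element counts from the sorted index list with a forward inclusive_scan_by_key over a constant iterator, followed by a reverse scan_by_key with maximum. A host-side Thrust sketch of the same trick (runnable without a GPU, assuming the Thrust headers that ship with the CUDA toolkit); the data matches the comment in the kernel code and the names are illustrative:

```cpp
#include <thrust/host_vector.h>
#include <thrust/scan.h>
#include <thrust/functional.h>
#include <thrust/iterator/constant_iterator.h>
#include <thrust/iterator/reverse_iterator.h>
#include <cstdio>
#include <vector>

int main() {
  // Same example as the comment in the kernel code:
  // sorted: 2 5 5 5 7 7 8 9 9
  std::vector<int> init{2, 5, 5, 5, 7, 7, 8, 9, 9};
  thrust::host_vector<int> sorted(init.begin(), init.end());
  thrust::host_vector<int> count(sorted.size());

  // Forward pass: running occurrence number within each run of equal keys
  // -> 1 1 2 3 1 2 1 1 2
  thrust::inclusive_scan_by_key(
      sorted.begin(), sorted.end(),
      thrust::make_constant_iterator(1),
      count.begin());

  // Backward pass: propagate each run's maximum (its length) to every element
  thrust::inclusive_scan_by_key(
      thrust::make_reverse_iterator(sorted.end()),
      thrust::make_reverse_iterator(sorted.begin()),
      thrust::make_reverse_iterator(count.end()),
      thrust::make_reverse_iterator(count.end()),
      thrust::equal_to<int>(),
      thrust::maximum<int>());

  for (int c : count) std::printf("%d ", c);  // prints: 1 3 3 3 2 2 1 2 2
  std::printf("\n");
  return 0;
}
```

The device code above does exactly this, only with device pointers, a CUDA execution policy, and index_t instead of int.
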
@@ -40,8 +40,9 @@ int64_t ceil_div(int64_t x, int64_t y) {
return (x + y - 1) / y;
}

template <typename index_t>
__global__
void krn_partials_per_segment(int64_t *ret, const int64_t *segment_offsets,
void krn_partials_per_segment(index_t *ret, const index_t *segment_offsets,
int64_t num_of_segments, int64_t numel) {
const int id = blockIdx.x * blockDim.x + threadIdx.x;
if(id < num_of_segments) {

@@ -52,18 +53,19 @@ void krn_partials_per_segment(int64_t *ret, const int64_t *segment_offsets,
}
}

template <typename index_t>
__global__
void krn_partial_segment_offset(
int64_t *ret,
index_t *ret,
const int64_t *partials_per_segment,
const index_t *partials_per_segment,
const int64_t *partials_per_segment_offset,
const index_t *partials_per_segment_offset,
const int64_t *segment_offsets,
const index_t *segment_offsets,
int64_t num_of_segments) {
const int id = blockIdx.x * blockDim.x + threadIdx.x;
if(id < num_of_segments) {
int64_t idx = partials_per_segment_offset[id];
index_t idx = partials_per_segment_offset[id];
const int64_t num_partials = partials_per_segment[id];
const index_t num_partials = partials_per_segment[id];
const int64_t segment_offset = segment_offsets[id];
const index_t segment_offset = segment_offsets[id];
for (int64_t i=0; i<num_partials; ++i) {
ret[idx++] = segment_offset + i * NROWS_PER_THREAD;
}

@@ -71,13 +73,13 @@ void krn_partial_segment_offset(
}

template <typename scalar_t>
template <typename scalar_t, typename index_t>
__global__ void compute_grad_weight_bags(
int64_t *indices, scalar_t *gradOutput,
index_t *indices, scalar_t *gradOutput,
int64_t *offset2bag, int64_t *count, ptrdiff_t numel,
index_t *offset2bag, index_t *count, ptrdiff_t numel,
int64_t stride, int mode_mean, const int64_t *bag_size,
int64_t stride, int mode_mean, const index_t *bag_size,
scalar_t* per_sample_weights, int64_t per_sample_weights_stride,
int64_t* segment_offsets, int64_t num_of_segments,
index_t* segment_offsets, int64_t num_of_segments,
acc_type<scalar_t, true> *grad_weight_per_segment,
const int64_t stride_warped) {

@@ -113,14 +115,14 @@ __global__ void compute_grad_weight_bags(
grad_weight_per_segment[id * stride + startFeature] = weight;
}

template <typename scalar_t>
template <typename scalar_t, typename index_t>
__global__ void compute_grad_weight(
int64_t *indices,
index_t *indices,
scalar_t *gradOutput,
int64_t *count,
index_t *count,
ptrdiff_t numel,
int64_t stride,
int64_t* segment_offsets,
index_t* segment_offsets,
int64_t num_of_segments,
acc_type<scalar_t, true> *grad_weight_per_segment,
const int64_t stride_warped) {

@@ -140,7 +142,7 @@ __global__ void compute_grad_weight(

accscalar_t weight = 0;
for (int idx=idx_begin; idx < idx_end; ++idx) {
const int64_t target_row = indices[idx];
const index_t target_row = indices[idx];
const accscalar_t scale = count ? (accscalar_t)1.0 / count[idx] : 1.0;
weight += gradOutput[target_row * stride + startFeature] * scale;
}

@@ -148,12 +150,12 @@ __global__ void compute_grad_weight(
}

// This kernel assumes that all input tensors are contiguous.
template <typename scalar_t>
template <typename scalar_t, typename index_t>
__global__ void sum_and_scatter(
int64_t *input, scalar_t *gradWeight, int64_t stride,
index_t *input, scalar_t *gradWeight, int64_t stride,
int64_t* segment_offsets, int64_t num_of_segments,
index_t* segment_offsets, int64_t num_of_segments,
const acc_type<scalar_t, true> *grad_weight_per_segment,
const int64_t *segment_sizes_offsets, int64_t num_of_partial_segments,
const index_t *segment_sizes_offsets, int64_t num_of_partial_segments,
const int64_t padding_idx,
const int64_t stride_warped) {

@@ -206,19 +208,20 @@ Tensor embedding_backward_cuda_kernel(
// spawn a warp per index. In this context, a segment is a number of rows that should
// be summarized.
// Unit: index in `sorted_indices` and `orig_indices`
AT_DISPATCH_INDEX_TYPES(orig_indices.scalar_type(), "embedding_backward_cuda_kernel", [&] () {
auto segment_offsets = at::empty({numel}, orig_indices.options());
int64_t num_of_segments;
{
auto sorted_indices_dev = thrust::device_ptr<int64_t>(sorted_indices.data_ptr<int64_t>());
auto sorted_indices_dev = thrust::device_ptr<index_t>(sorted_indices.data_ptr<index_t>());
auto dummy = at::empty_like(sorted_indices, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
auto dummy_dev = thrust::device_ptr<int64_t>(dummy.data_ptr<int64_t>());
auto dummy_dev = thrust::device_ptr<index_t>(dummy.data_ptr<index_t>());
auto ends = thrust::unique_by_key_copy(
policy,
sorted_indices_dev,
sorted_indices_dev + numel,
thrust::make_counting_iterator(0),
dummy_dev,
thrust::device_ptr<int64_t>(segment_offsets.data_ptr<int64_t>()));
thrust::device_ptr<index_t>(segment_offsets.data_ptr<index_t>()));
num_of_segments = thrust::get<0>(ends) - dummy_dev;
}

@@ -228,8 +231,8 @@ Tensor embedding_backward_cuda_kernel(
auto partials_per_segment = at::empty({num_of_segments}, orig_indices.options());
{
krn_partials_per_segment<<<ceil_div(num_of_segments, 32), 32, 0, stream>>> (
partials_per_segment.data_ptr<int64_t>(),
partials_per_segment.data_ptr<index_t>(),
segment_offsets.data_ptr<int64_t>(),
segment_offsets.data_ptr<index_t>(),
num_of_segments,
numel);
}

@@ -241,23 +244,23 @@ Tensor embedding_backward_cuda_kernel(
auto partials_per_segment_offset = at::empty({num_of_segments}, orig_indices.options());
thrust::exclusive_scan(
policy,
thrust::device_ptr<int64_t>(partials_per_segment.data_ptr<int64_t>()),
thrust::device_ptr<index_t>(partials_per_segment.data_ptr<index_t>()),
thrust::device_ptr<int64_t>(partials_per_segment.data_ptr<int64_t>()+num_of_segments),
thrust::device_ptr<index_t>(partials_per_segment.data_ptr<index_t>()+num_of_segments),
thrust::device_ptr<int64_t>(partials_per_segment_offset.data_ptr<int64_t>()));
thrust::device_ptr<index_t>(partials_per_segment_offset.data_ptr<index_t>()));

// The total number of partial-segments is the sum of `partials_per_segment_offset`
const int num_of_partial_segments = partials_per_segment[num_of_segments-1].item<int64_t>() +
const int num_of_partial_segments = partials_per_segment[num_of_segments-1].item<index_t>() +
partials_per_segment_offset[num_of_segments-1].item<int64_t>();
partials_per_segment_offset[num_of_segments-1].item<index_t>();

// Now we can compute the start position of each partial-segment
// Unit: index in `sorted_indices` and `orig_indices`
auto partial_segment_offset = at::empty({num_of_partial_segments}, orig_indices.options());
{
krn_partial_segment_offset<<<ceil_div(num_of_segments, 32), 32, 0, stream>>> (
partial_segment_offset.data_ptr<int64_t>(),
partial_segment_offset.data_ptr<index_t>(),
partials_per_segment.data_ptr<int64_t>(),
partials_per_segment.data_ptr<index_t>(),
partials_per_segment_offset.data_ptr<int64_t>(),
partials_per_segment_offset.data_ptr<index_t>(),
segment_offsets.data_ptr<int64_t>(),
segment_offsets.data_ptr<index_t>(),
num_of_segments);
}

@@ -281,23 +284,23 @@ Tensor embedding_backward_cuda_kernel(
// Compute the sum of each partial-segment and handle bags
if (offset2bag.defined()) {
compute_grad_weight_bags<scalar_t><<<grid, block, 0, stream>>>(
orig_indices.data_ptr<int64_t>(),
orig_indices.data_ptr<index_t>(),
grad.data_ptr<scalar_t>(),
offset2bag.data_ptr<int64_t>(),
offset2bag.data_ptr<index_t>(),
count.defined() ? count.data_ptr<int64_t>() : nullptr, numel, stride,
count.defined() ? count.data_ptr<index_t>() : nullptr, numel, stride,
mode_mean, bag_size.data_ptr<int64_t>(),
mode_mean, bag_size.data_ptr<index_t>(),
per_sample_weights.defined() ? per_sample_weights.data_ptr<scalar_t>() : NULL,
per_sample_weights.defined() ? per_sample_weights.stride(0) : 0,
partial_segment_offset.data_ptr<int64_t>(),
partial_segment_offset.data_ptr<index_t>(),
num_of_partial_segments, grad_weight_per_segment.data_ptr<partial_weight_t>(),
stride_warped);
} else {
compute_grad_weight<scalar_t><<<grid, block, 0, stream>>>(
orig_indices.data_ptr<int64_t>(),
orig_indices.data_ptr<index_t>(),
grad.data_ptr<scalar_t>(),
count.defined() ? count.data_ptr<int64_t>() : nullptr,
count.defined() ? count.data_ptr<index_t>() : nullptr,
numel, stride,
partial_segment_offset.data_ptr<int64_t>(),
partial_segment_offset.data_ptr<index_t>(),
num_of_partial_segments,
grad_weight_per_segment.data_ptr<partial_weight_t>(),
stride_warped);

@@ -308,18 +311,19 @@ Tensor embedding_backward_cuda_kernel(
// into `grad_weight`.
const int grid2 = ceil_div(num_of_segments*stride_warped, block);
sum_and_scatter<scalar_t><<<grid2, block, 0, stream>>>(
sorted_indices.data_ptr<int64_t>(),
sorted_indices.data_ptr<index_t>(),
grad_weight.data_ptr<scalar_t>(),
stride,
segment_offsets.data_ptr<int64_t>(),
segment_offsets.data_ptr<index_t>(),
num_of_segments, grad_weight_per_segment.data_ptr<partial_weight_t>(),
partials_per_segment_offset.data_ptr<int64_t>(),
partials_per_segment_offset.data_ptr<index_t>(),
num_of_partial_segments,
padding_idx,
stride_warped);
AT_CUDA_CHECK(cudaGetLastError());
});
});
});
return grad_weight;
}

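embedding_backward_cuda_kernel above first computes segment_offsets (the position where each run of equal values starts in sorted_indices) by passing a counting iterator as the values of thrust::unique_by_key_copy. A host-side sketch of that step, under the same assumptions as the previous example; variable names are illustrative:

```cpp
#include <thrust/host_vector.h>
#include <thrust/unique.h>
#include <thrust/iterator/counting_iterator.h>
#include <cstdio>
#include <vector>

int main() {
  // sorted indices: 2 5 5 5 7 7 8 9 9
  std::vector<int> init{2, 5, 5, 5, 7, 7, 8, 9, 9};
  thrust::host_vector<int> sorted(init.begin(), init.end());
  thrust::host_vector<int> unique_keys(sorted.size());
  thrust::host_vector<int> segment_offsets(sorted.size());

  // Keys are the sorted indices, values are their positions 0..n-1; keeping
  // the first value of each run yields the start offset of every segment.
  auto ends = thrust::unique_by_key_copy(
      sorted.begin(), sorted.end(),
      thrust::make_counting_iterator(0),
      unique_keys.begin(),
      segment_offsets.begin());
  auto num_of_segments = ends.first - unique_keys.begin();

  for (int i = 0; i < num_of_segments; ++i) {
    std::printf("%d ", (int)segment_offsets[i]);  // prints: 0 1 4 6 7
  }
  std::printf("\n");
  return 0;
}
```

Each segment is then split into partial segments of at most NROWS_PER_THREAD rows, reduced per partial segment, and finally scattered into grad_weight, as the kernels above show.
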
@@ -31,12 +31,12 @@ constexpr int MODE_MAX = 2;

// This kernel assumes that all input tensors except `weight` and
// per_sample_weights are contiguous.
template <typename scalar_t>
template <typename scalar_t, typename index_t>
__global__ void EmbeddingBag_updateOutputKernel(
int64_t *input, int64_t *offsets, scalar_t *weight, scalar_t *output,
index_t *input, index_t *offsets, scalar_t *weight, scalar_t *output,
int64_t *offset2bag, int64_t numIndices, int64_t numBags,
index_t *offset2bag, int64_t numIndices, int64_t numBags,
int64_t featureSize, int64_t weight_stride0, int64_t weight_stride1,
int mode, int64_t *bag_size, int64_t *max_indices,
int mode, index_t *bag_size, index_t *max_indices,
scalar_t* per_sample_weights, int64_t per_sample_weights_stride) {

// the strategy here is that each bag x feature is handled by a single thread

@@ -135,7 +135,10 @@ Tensor embedding_bag_backward_cuda_sum_avg(

auto sorted_indices = at::empty_like(indices, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
auto orig_indices = at::empty_like(indices, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
using device_ptr = thrust::device_ptr<int64_t>;
Tensor count;

AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_bag_backward_cuda_sum_avg", [&] () {
using device_ptr = thrust::device_ptr<index_t>;

// Sort the inputs into sorted with the corresponding indices; we
// don't need a stable or multidimensional sort, so just use Thrust

@@ -147,17 +150,16 @@ Tensor embedding_bag_backward_cuda_sum_avg(
auto policy = thrust::cuda::par(allocator).on(stream);

// Fill sortedOrigIndices with sequential indices
auto count_iter = thrust::counting_iterator<int64_t>(0);
auto count_iter = thrust::counting_iterator<index_t>(0);
auto orig_data = device_ptr(orig_indices.data_ptr<int64_t>());
auto orig_data = device_ptr(orig_indices.data_ptr<index_t>());
thrust::copy(policy, count_iter, count_iter + numel, orig_data);

// Sort; a stable sort is not required
auto sorted_data = device_ptr(sorted_indices.data_ptr<int64_t>());
auto sorted_data = device_ptr(sorted_indices.data_ptr<index_t>());
thrust::sort_by_key(policy, sorted_data, sorted_data + numel, orig_data,
ThrustLTOp<int64_t>());
ThrustLTOp<index_t>());
}

Tensor count;
if (scale_grad_by_freq) {
count = at::empty_like(indices, LEGACY_CONTIGUOUS_MEMORY_FORMAT);

@@ -167,8 +169,8 @@ Tensor embedding_bag_backward_cuda_sum_avg(
// Compute an increasing sequence per unique item in sortedIndices:
// sorted: 2 5 5 5 7 7 8 9 9
// count: 1 1 2 3 1 2 1 1 2
auto sorted_data = device_ptr(sorted_indices.data_ptr<int64_t>());
auto sorted_data = device_ptr(sorted_indices.data_ptr<index_t>());
auto count_data = device_ptr(count.data_ptr<int64_t>());
auto count_data = device_ptr(count.data_ptr<index_t>());
thrust::inclusive_scan_by_key(policy, sorted_data, sorted_data + numel,
thrust::make_constant_iterator(1),
count_data);

@@ -181,16 +183,17 @@ Tensor embedding_bag_backward_cuda_sum_avg(
thrust::make_reverse_iterator(sorted_data),
thrust::make_reverse_iterator(count_data + numel),
thrust::make_reverse_iterator(count_data + numel),
thrust::equal_to<int64_t>(), thrust::maximum<int64_t>());
thrust::equal_to<index_t>(), thrust::maximum<index_t>());
}
});
return embedding_backward_cuda_kernel(grad, orig_indices, sorted_indices,
count, num_weights, /* padding_idx= */ -1, scale_grad_by_freq,
mode == MODE_MEAN, offset2bag, bag_size, per_sample_weights);
}

template <typename scalar_t>
template <typename scalar_t, typename index_t>
__global__ void EmbeddingBag_accGradParametersKernel_max(
int64_t *max_indices, scalar_t *gradOutput,
index_t *max_indices, scalar_t *gradOutput,
scalar_t *gradWeight, int64_t stride, int64_t numBags) {

using accscalar_t = acc_type<scalar_t, true>;

@@ -205,7 +208,7 @@ __global__ void EmbeddingBag_accGradParametersKernel_max(
if (featureDim < stride) {
int64_t bag = chunk / chunksPerBag;

int64_t word_idx = max_indices[bag * stride + featureDim];
index_t word_idx = max_indices[bag * stride + featureDim];
if (word_idx >= 0) {
// If bag is empty, we have max_indices[idx] set to -1 in forward.
gpuAtomicAdd(&(gradWeight[word_idx * stride + featureDim]),

@@ -236,11 +239,13 @@ Tensor embedding_bag_backward_cuda_max(const Tensor &grad,

AT_DISPATCH_FLOATING_TYPES_AND_HALF(
grad.scalar_type(), "embedding_bag_backward_cuda_max", [&] {
AT_DISPATCH_INDEX_TYPES(max_indices.scalar_type(), "embedding_bag_backward_cuda_max", [&] () {
EmbeddingBag_accGradParametersKernel_max<
scalar_t><<<grid, block, 0, stream>>>(
scalar_t, index_t><<<grid, block, 0, stream>>>(
max_indices.data_ptr<int64_t>(), grad.data_ptr<scalar_t>(),
max_indices.data_ptr<index_t>(), grad.data_ptr<scalar_t>(),
grad_weight.data_ptr<scalar_t>(), stride, numBags);
});
});

AT_CUDA_CHECK(cudaGetLastError());
return grad_weight;

@@ -275,9 +280,10 @@ _embedding_bag_cuda(const Tensor &weight, const Tensor &indices,
const Tensor& per_sample_weights,
bool include_last_offset) {
auto indices_arg = TensorArg(indices, "indices", 1);
checkScalarType("embedding_bag_cuda", indices_arg, kLong);
checkScalarTypes("embedding_bag_cuda", indices_arg, {kLong, kInt});
auto offsets_arg = TensorArg(offsets, "offsets", 1);
checkScalarType("embedding_bag_cuda", offsets_arg, kLong);
checkScalarTypes("embedding_bag_cuda", offsets_arg, {kLong, kInt});
checkSameType("embedding_bag_cuda", indices_arg, offsets_arg);
auto weight_arg = TensorArg(weight, "weight", 1);
checkSameGPU("embedding_bag_cuda", weight_arg, indices_arg);
checkSameGPU("embedding_bag_cuda", weight_arg, offsets_arg);

@@ -320,16 +326,18 @@ _embedding_bag_cuda(const Tensor &weight, const Tensor &indices,
int grid = 1024;
AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, weight.scalar_type(), "embedding_bag_cuda", [&] {
AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "embedding_bag_cuda", [&] {
EmbeddingBag_updateOutputKernel<scalar_t><<<grid, block, 0, stream>>>(
AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_bag_cuda", [&] () {
indices.data_ptr<int64_t>(), offsets.data_ptr<int64_t>(),
EmbeddingBag_updateOutputKernel<scalar_t, index_t><<<grid, block, 0, stream>>>(
indices.data_ptr<index_t>(), offsets.data_ptr<index_t>(),
weight.data_ptr<scalar_t>(), output.data_ptr<scalar_t>(),
offset2bag.data_ptr<int64_t>(), numIndices, numBags, featureSize,
offset2bag.data_ptr<index_t>(), numIndices, numBags, featureSize,
weight.stride(0), weight.stride(1), mode, bag_size.data_ptr<int64_t>(),
weight.stride(0), weight.stride(1), mode, bag_size.data_ptr<index_t>(),
mode == MODE_MAX ? max_indices.data_ptr<int64_t>() : NULL,
mode == MODE_MAX ? max_indices.data_ptr<index_t>() : NULL,
per_sample_weights.defined() ? per_sample_weights.data_ptr<scalar_t>() : NULL,
per_sample_weights.defined() ? per_sample_weights.stride(0) : 0);
});
});
});

AT_CUDA_CHECK(cudaGetLastError());
return std::tuple<Tensor, Tensor, Tensor, Tensor>(output, offset2bag, bag_size, max_indices);

@@ -387,12 +395,12 @@ static scalar_t warpReduceSum(scalar_t val) {
return val;
}

template <typename scalar_t>
template <typename scalar_t, typename index_t>
__global__ static void _embedding_bag_per_sample_weights_backward_kernel(
const scalar_t* grad, int64_t grad_stride0, int64_t grad_stride1,
const scalar_t* weight, int64_t weight_stride0, int64_t weight_stride1,
const int64_t* indices, // contiguous
const index_t* indices, // contiguous
const int64_t* offset2bag, // contiguous
const index_t* offset2bag, // contiguous
int64_t num_samples,
int64_t embedding_features,
scalar_t* output) {

@@ -457,16 +465,18 @@ Tensor _embedding_bag_per_sample_weights_backward_cuda(

AT_DISPATCH_FLOATING_TYPES_AND_HALF(
grad.scalar_type(), "_embedding_bag_per_sample_weights_backward_cuda", [&]() {
_embedding_bag_per_sample_weights_backward_kernel<scalar_t>
AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "_embedding_bag_per_sample_weights_backward_cuda", [&]() {
_embedding_bag_per_sample_weights_backward_kernel<scalar_t, index_t>
<<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(
grad.data_ptr<scalar_t>(), grad.stride(0), grad.stride(1),
weight.data_ptr<scalar_t>(), weight.stride(0), weight.stride(1),
indices.data_ptr<int64_t>(),
indices.data_ptr<index_t>(),
offset2bag.data_ptr<int64_t>(),
offset2bag.data_ptr<index_t>(),
num_samples,
embedding_features,
output.data_ptr<scalar_t>());
AT_CUDA_CHECK(cudaGetLastError());
});
}
);
return output;

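The EmbeddingBag forward above now accepts either int64 or int32 for indices and offsets, but requires the two tensors to share a dtype. A minimal sketch of that validation, assuming the ATen TensorUtils helpers used in the diff (TensorArg, checkScalarTypes, checkSameType); the function name check_bag_index_args is hypothetical:

```cpp
#include <ATen/ATen.h>
#include <ATen/TensorUtils.h>

// Mirrors the argument checks in the embedding_bag forward: indices/offsets may
// be kLong (int64) or kInt (int32), but mixing the two dtypes is rejected.
void check_bag_index_args(const at::Tensor& indices, const at::Tensor& offsets) {
  at::TensorArg indices_arg(indices, "indices", 1);
  at::TensorArg offsets_arg(offsets, "offsets", 2);
  at::checkScalarTypes("check_bag_index_args", indices_arg, {at::kLong, at::kInt});
  at::checkScalarTypes("check_bag_index_args", offsets_arg, {at::kLong, at::kInt});
  at::checkSameType("check_bag_index_args", indices_arg, offsets_arg);
}
```

In practice this means an int32 indices tensor must be paired with int32 offsets, which matches the constraint stated in the commit summary.
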
@@ -308,10 +308,10 @@ static ptrdiff_t getSliceSize(const Tensor & dst,
// the number of indices chosen is large, then the
// indexAddLargeIndex kernel is a better choice to increase
// parallelism.
template <typename T, typename IndexType, int DstDim, int SrcDim, int IdxDim>
template <typename T, typename IndicesType, typename IndexType, int DstDim, int SrcDim, int IdxDim>
__global__ void indexAddSmallIndex(cuda::detail::TensorInfo<T, IndexType> dst,
cuda::detail::TensorInfo<T, IndexType> src,
cuda::detail::TensorInfo<int64_t, IndexType> indices,
cuda::detail::TensorInfo<IndicesType, IndexType> indices,
int dstAddDim,
int srcAddDim,
IndexType innerSize,

@@ -324,7 +324,7 @@ __global__ void indexAddSmallIndex(cuda::detail::TensorInfo<T, IndexType> dst,
for (IndexType srcIndex = 0; srcIndex < indices.sizes[0]; ++srcIndex) {
// Lua indices begin at 1
IndexType dstIndex =
indices.data[cuda::detail::IndexToOffset<int64_t, IndexType, IdxDim>::get(srcIndex, indices)];
indices.data[cuda::detail::IndexToOffset<IndicesType, IndexType, IdxDim>::get(srcIndex, indices)];
CUDA_KERNEL_ASSERT(dstIndex < dstAddDimSize);

// We stride over the output ignoring the indexed dimension

@@ -351,11 +351,11 @@ __global__ void indexAddSmallIndex(cuda::detail::TensorInfo<T, IndexType> dst,
// the number of indices chosen is small, then the
// indexAddSmallIndex kernel is a better choice to reduce memory
// accesses.
template <typename T, typename IndexType, int DstDim, int SrcDim, int IdxDim,
template <typename T, typename IndicesType, typename IndexType, int DstDim, int SrcDim, int IdxDim,
bool IndexIsMajor>
__global__ void indexAddLargeIndex(cuda::detail::TensorInfo<T, IndexType> dst,
cuda::detail::TensorInfo<T, IndexType> src,
cuda::detail::TensorInfo<int64_t, IndexType> indices,
cuda::detail::TensorInfo<IndicesType, IndexType> indices,
int dstAddDim,
int srcAddDim,
IndexType totalSize,

@@ -378,7 +378,7 @@ __global__ void indexAddLargeIndex(cuda::detail::TensorInfo<T, IndexType> dst,

// Lua indices begin at 1
IndexType dstIndex =
indices.data[cuda::detail::IndexToOffset<int64_t, IndexType, IdxDim>::get(srcIndex, indices)];
indices.data[cuda::detail::IndexToOffset<IndicesType, IndexType, IdxDim>::get(srcIndex, indices)];
CUDA_KERNEL_ASSERT(dstIndex < dstAddDimSize);

IndexType dstOffset =

@@ -438,7 +438,7 @@ Tensor& index_add_cuda_(Tensor & self, int64_t dim, const Tensor & index, const
checkAllSameGPU("index_add", {self_arg, index_arg, source_arg});

TORCH_CHECK_INDEX(index.dim() <= 1, "index_add_(): Index is supposed to be a vector");
TORCH_CHECK(index.scalar_type() == ScalarType::Long, "index_add_(): Expected dtype int64 for index");
TORCH_CHECK(index.scalar_type() == ScalarType::Long || index.scalar_type() == ScalarType::Int, "index_add_(): Expected dtype int32/int64 for index");
TORCH_CHECK(self.scalar_type() == source.scalar_type(),
"index_add_(): self and source must have the same scalar type");
TORCH_CHECK(dim == 0 || dim < source.dim(),

@@ -476,15 +476,15 @@ Tensor& index_add_cuda_(Tensor & self, int64_t dim, const Tensor & index, const

int mpc = at::cuda::getCurrentDeviceProperties()->multiProcessorCount;

#define SMALL_INDEX(TENSOR_TYPE, TYPE, SELF_DIM, SOURCE_DIM, IDX_DIM) \
#define SMALL_INDEX(TENSOR_TYPE, INDICES_TYPE, TYPE, SELF_DIM, SOURCE_DIM, IDX_DIM) \
indexAddSmallIndex<TENSOR_TYPE, TYPE, SELF_DIM, SOURCE_DIM, IDX_DIM> \
indexAddSmallIndex<TENSOR_TYPE, INDICES_TYPE, TYPE, SELF_DIM, SOURCE_DIM, IDX_DIM> \
<<<smallIndexGrid, smallIndexBlock, 0, stream>>>( \
selfInfo, sourceInfo, indexInfo, \
selfAddDim, sourceAddDim, sliceSize, selfAddDimSize);

#define LARGE_INDEX(TENSOR_TYPE, TYPE, \
#define LARGE_INDEX(TENSOR_TYPE, INDICES_TYPE, TYPE, \
SELF_DIM, SOURCE_DIM, IDX_DIM, IDX_IS_MAJOR) \
indexAddLargeIndex<TENSOR_TYPE, TYPE, \
indexAddLargeIndex<TENSOR_TYPE, INDICES_TYPE, TYPE, \
SELF_DIM, SOURCE_DIM, IDX_DIM, IDX_IS_MAJOR> \
<<<largeIndexGrid, largeIndexBlock, 0, stream>>>( \
selfInfo, sourceInfo, indexInfo, \

@@ -507,51 +507,52 @@ Tensor& index_add_cuda_(Tensor & self, int64_t dim, const Tensor & index, const
cuda::detail::getTensorInfo<scalar_t, unsigned int>(self_);
int selfAddDim = selfInfo.collapseDims(dim);
selfInfo.reduceDim(selfAddDim);
AT_DISPATCH_INDEX_TYPES(index.scalar_type(), "index_add_cuda_", [&] () {
auto sourceInfo =
cuda::detail::getTensorInfo<scalar_t, unsigned int>(source_);
int sourceAddDim = sourceInfo.collapseDims(dim);
sourceInfo.reduceDim(sourceAddDim);

auto indexInfo =
cuda::detail::getTensorInfo<int64_t, unsigned int>(index);
cuda::detail::getTensorInfo<index_t, unsigned int>(index);
indexInfo.collapseDims();

// A reasonable choice for when to have each thread iterate over
// index to choose
if (numIndex <= 16) {
if (selfInfo.dims == 1 && sourceInfo.dims == 1 && indContig) {
SMALL_INDEX(scalar_t, unsigned int, 1, 1, -2);
SMALL_INDEX(scalar_t, index_t, unsigned int, 1, 1, -2);
} else if (selfInfo.dims == 2 && sourceInfo.dims == 2 && indContig) {
SMALL_INDEX(scalar_t, unsigned int, 2, 2, -2);
SMALL_INDEX(scalar_t, index_t, unsigned int, 2, 2, -2);
} else if (selfInfo.dims == 3 && sourceInfo.dims == 3 && indContig) {
SMALL_INDEX(scalar_t, unsigned int, 3, 3, -2);
SMALL_INDEX(scalar_t, index_t, unsigned int, 3, 3, -2);
} else {
SMALL_INDEX(scalar_t, unsigned int, -1, -1, -1);
SMALL_INDEX(scalar_t, index_t, unsigned int, -1, -1, -1);
}
} else {
bool indexIsMajor = indexShouldBeMajor(selfInfo, selfAddDim);

if (selfInfo.dims == 1 && sourceInfo.dims == 1 && indContig) {
LARGE_INDEX(scalar_t, unsigned int, 1, 1, -2, true);
LARGE_INDEX(scalar_t, index_t, unsigned int, 1, 1, -2, true);
} else if (selfInfo.dims == 2 && sourceInfo.dims == 2 && indContig) {
if (indexIsMajor) {
LARGE_INDEX(scalar_t, unsigned int, 2, 2, -2, true);
LARGE_INDEX(scalar_t, index_t, unsigned int, 2, 2, -2, true);
} else {
LARGE_INDEX(scalar_t, unsigned int, 2, 2, -2, false);
LARGE_INDEX(scalar_t, index_t, unsigned int, 2, 2, -2, false);
}
} else if (selfInfo.dims == 3 && sourceInfo.dims == 3 && indContig) {
if (indexIsMajor) {
LARGE_INDEX(scalar_t, unsigned int, 3, 3, -2, true);
LARGE_INDEX(scalar_t, index_t, unsigned int, 3, 3, -2, true);
} else {
LARGE_INDEX(scalar_t, unsigned int, 3, 3, -2, false);
LARGE_INDEX(scalar_t, index_t, unsigned int, 3, 3, -2, false);
}
} else {
LARGE_INDEX(scalar_t, unsigned int, -1, -1, -1, true);
LARGE_INDEX(scalar_t, index_t, unsigned int, -1, -1, -1, true);
}
}
});
});
});
} else {
AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, self.scalar_type(), "index_add", [&] {
AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "index_add", [&] {

@@ -565,11 +566,13 @@ Tensor& index_add_cuda_(Tensor & self, int64_t dim, const Tensor & index, const
int sourceAddDim = sourceInfo.collapseDims(dim);
sourceInfo.reduceDim(sourceAddDim);

cuda::detail::TensorInfo<int64_t, uint64_t> indexInfo =
AT_DISPATCH_INDEX_TYPES(index.scalar_type(), "index_add_cuda_", [&] () {
cuda::detail::getTensorInfo<int64_t, uint64_t>(index);
cuda::detail::TensorInfo<index_t, uint64_t> indexInfo =
cuda::detail::getTensorInfo<index_t, uint64_t>(index);
indexInfo.collapseDims();

LARGE_INDEX(scalar_t, uint64_t, -1, -1, -1, true);
LARGE_INDEX(scalar_t, index_t, uint64_t, -1, -1, -1, true);
});
});
});
}

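index_add_cuda_ above nests AT_DISPATCH_INDEX_TYPES inside the scalar-type dispatch so that a single kernel launch sees both scalar_t and index_t. A CPU-flavoured sketch of the same nesting, assuming an ATen build that includes this commit; gather_rows and its behaviour are illustrative, not the diff's implementation:

```cpp
#include <ATen/ATen.h>
#include <ATen/Dispatch.h>
#include <algorithm>

// Gathers the rows of a 2-D, contiguous floating-point tensor `src` named by
// `index` (int32 or int64). The outer dispatch fixes scalar_t from the data
// dtype and the inner one fixes index_t from the index dtype, so one body
// covers every combination, the same structure the CUDA code above uses.
at::Tensor gather_rows(const at::Tensor& src, const at::Tensor& index) {
  auto src_c = src.contiguous();
  auto index_c = index.contiguous();
  auto out = at::empty({index_c.numel(), src_c.size(1)}, src_c.options());
  AT_DISPATCH_FLOATING_TYPES(src_c.scalar_type(), "gather_rows", [&] {
    AT_DISPATCH_INDEX_TYPES(index_c.scalar_type(), "gather_rows", [&] {
      const scalar_t* src_data = src_c.data_ptr<scalar_t>();
      const index_t* idx_data = index_c.data_ptr<index_t>();
      scalar_t* out_data = out.data_ptr<scalar_t>();
      const int64_t row_len = src_c.size(1);
      for (int64_t i = 0; i < index_c.numel(); ++i) {
        const int64_t row = static_cast<int64_t>(idx_data[i]);
        std::copy(src_data + row * row_len, src_data + (row + 1) * row_len,
                  out_data + i * row_len);
      }
    });
  });
  return out;
}
```
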
@@ -586,10 +589,10 @@ namespace {
// the number of indices chosen is large, then the
// indexSelectLargeIndex kernel is a better choice to increase
// parallelism.
template <typename T, typename IndexType, int DstDim, int SrcDim, int IdxDim>
template <typename T, typename IndicesType, typename IndexType, int DstDim, int SrcDim, int IdxDim>
__global__ void indexSelectSmallIndex(cuda::detail::TensorInfo<T, IndexType> dst,
cuda::detail::TensorInfo<T, IndexType> src,
cuda::detail::TensorInfo<int64_t, IndexType> indices,
cuda::detail::TensorInfo<IndicesType, IndexType> indices,
int dstSelectDim,
int srcSelectDim,
IndexType innerSize,

@@ -601,7 +604,7 @@ __global__ void indexSelectSmallIndex(cuda::detail::TensorInfo<T, IndexType> dst
// re-accessing indices in addition to src elements can be slow.
for (IndexType dstIndex = 0; dstIndex < indices.sizes[0]; ++dstIndex) {
IndexType srcIndex =
indices.data[cuda::detail::IndexToOffset<int64_t, IndexType, IdxDim>::get(dstIndex, indices)];
indices.data[cuda::detail::IndexToOffset<IndicesType, IndexType, IdxDim>::get(dstIndex, indices)];
CUDA_KERNEL_ASSERT(srcIndex < srcSelectDimSize);

// We stride over the output ignoring the indexed dimension

@@ -628,11 +631,11 @@ __global__ void indexSelectSmallIndex(cuda::detail::TensorInfo<T, IndexType> dst
// the number of indices chosen is small, then the
// indexSelectSmallIndex kernel is a better choice to reduce memory
// accesses.
template <typename T, typename IndexType, int DstDim, int SrcDim, int IdxDim,
template <typename T, typename IndicesType, typename IndexType, int DstDim, int SrcDim, int IdxDim,
bool IndexIsMajor>
__global__ void indexSelectLargeIndex(cuda::detail::TensorInfo<T, IndexType> dst,
cuda::detail::TensorInfo<T, IndexType> src,
cuda::detail::TensorInfo<int64_t, IndexType> indices,
cuda::detail::TensorInfo<IndicesType, IndexType> indices,
int dstSelectDim,
int srcSelectDim,
IndexType totalSize,

@@ -654,7 +657,7 @@ __global__ void indexSelectLargeIndex(cuda::detail::TensorInfo<T, IndexType> dst
}

IndexType srcIndex =
indices.data[cuda::detail::IndexToOffset<int64_t, IndexType, IdxDim>::get(dstIndex, indices)];
indices.data[cuda::detail::IndexToOffset<IndicesType, IndexType, IdxDim>::get(dstIndex, indices)];
CUDA_KERNEL_ASSERT(srcIndex < srcSelectDimSize);

IndexType dstOffset =

@@ -722,16 +725,16 @@ void index_select_out_cuda_impl(Tensor& out, const Tensor& self, long dim,

int mpc = at::cuda::getCurrentDeviceProperties()->multiProcessorCount;

#define SMALL_INDEX(TENSOR_TYPE, TYPE, DST_DIM, SRC_DIM, IDX_DIM) \
#define SMALL_INDEX(TENSOR_TYPE, INDICES_TYPE, TYPE, DST_DIM, SRC_DIM, IDX_DIM) \
indexSelectSmallIndex<TENSOR_TYPE, TYPE, DST_DIM, SRC_DIM, IDX_DIM> \
indexSelectSmallIndex<TENSOR_TYPE, INDICES_TYPE, TYPE, DST_DIM, SRC_DIM, IDX_DIM> \
<<<smallIndexGrid, smallIndexBlock, 0, stream>>>( \
outInfo, selfInfo, indicesInfo, \
outSelectDim, selfSelectDim, static_cast<TYPE>(sliceSize), \
selfSelectDimSize);

#define LARGE_INDEX(TENSOR_TYPE, TYPE, \
#define LARGE_INDEX(TENSOR_TYPE, INDICES_TYPE, TYPE, \
DST_DIM, SRC_DIM, IDX_DIM, IDX_IS_MAJOR) \
indexSelectLargeIndex<TENSOR_TYPE, TYPE, \
indexSelectLargeIndex<TENSOR_TYPE, INDICES_TYPE, TYPE, \
DST_DIM, SRC_DIM, IDX_DIM, IDX_IS_MAJOR> \
<<<largeIndexGrid, largeIndexBlock, 0, stream>>>( \
outInfo, selfInfo, indicesInfo, \

@@ -755,42 +758,44 @@ void index_select_out_cuda_impl(Tensor& out, const Tensor& self, long dim,
int selfSelectDim = selfInfo.collapseDims(dim);
selfInfo.reduceDim(selfSelectDim);

auto indicesInfo = tensorInfoLegacyIfScalar(cuda::detail::getTensorInfo<int64_t, unsigned int>(index));
AT_DISPATCH_INDEX_TYPES(index.scalar_type(), "index_select_out_cuda_impl", [&] () {
auto indicesInfo = tensorInfoLegacyIfScalar(cuda::detail::getTensorInfo<index_t, unsigned int>(index));
indicesInfo.collapseDims();

// A reasonable choice for when to have each thread iterate over
// indices to choose
if (numIndices <= 16) {
if (outInfo.dims == 1 && selfInfo.dims == 1 && indContig) {
SMALL_INDEX(scalar_t, unsigned int, 1, 1, -2);
SMALL_INDEX(scalar_t, index_t, unsigned int, 1, 1, -2);
} else if (outInfo.dims == 2 && selfInfo.dims == 2 && indContig) {
SMALL_INDEX(scalar_t, unsigned int, 2, 2, -2);
SMALL_INDEX(scalar_t, index_t, unsigned int, 2, 2, -2);
} else if (outInfo.dims == 3 && selfInfo.dims == 3 && indContig) {
SMALL_INDEX(scalar_t, unsigned int, 3, 3, -2);
SMALL_INDEX(scalar_t, index_t, unsigned int, 3, 3, -2);
} else {
SMALL_INDEX(scalar_t, unsigned int, -1, -1, -1);
SMALL_INDEX(scalar_t, index_t, unsigned int, -1, -1, -1);
}
} else {
bool indexIsMajor = indexShouldBeMajor(outInfo, outSelectDim);

if (outInfo.dims == 1 && selfInfo.dims == 1 && indContig) {
LARGE_INDEX(scalar_t, unsigned int, 1, 1, -2, true);
LARGE_INDEX(scalar_t, index_t, unsigned int, 1, 1, -2, true);
} else if (outInfo.dims == 2 && selfInfo.dims == 2 && indContig) {
if (indexIsMajor) {
LARGE_INDEX(scalar_t, unsigned int, 2, 2, -2, true);
LARGE_INDEX(scalar_t, index_t, unsigned int, 2, 2, -2, true);
} else {
LARGE_INDEX(scalar_t, unsigned int, 2, 2, -2, false);
LARGE_INDEX(scalar_t, index_t, unsigned int, 2, 2, -2, false);
}
} else if (outInfo.dims == 3 && selfInfo.dims == 3 && indContig) {
if (indexIsMajor) {
LARGE_INDEX(scalar_t, unsigned int, 3, 3, -2, true);
LARGE_INDEX(scalar_t, index_t, unsigned int, 3, 3, -2, true);
} else {
LARGE_INDEX(scalar_t, unsigned int, 3, 3, -2, false);
LARGE_INDEX(scalar_t, index_t, unsigned int, 3, 3, -2, false);
}
} else {
LARGE_INDEX(scalar_t, unsigned int, -1, -1, -1, true);
LARGE_INDEX(scalar_t, index_t, unsigned int, -1, -1, -1, true);
}
}
});
} else {
auto outInfo = tensorInfoLegacyIfScalar(cuda::detail::getTensorInfo<scalar_t, uint64_t>(out));
int outSelectDim = outInfo.collapseDims(dim);

@@ -799,11 +804,12 @@ void index_select_out_cuda_impl(Tensor& out, const Tensor& self, long dim,
|
||||||
auto selfInfo = tensorInfoLegacyIfScalar(cuda::detail::getTensorInfo<scalar_t, uint64_t>(self));
|
auto selfInfo = tensorInfoLegacyIfScalar(cuda::detail::getTensorInfo<scalar_t, uint64_t>(self));
|
||||||
int selfSelectDim = selfInfo.collapseDims(dim);
|
int selfSelectDim = selfInfo.collapseDims(dim);
|
||||||
selfInfo.reduceDim(selfSelectDim);
|
selfInfo.reduceDim(selfSelectDim);
|
||||||
|
AT_DISPATCH_INDEX_TYPES(index.scalar_type(), "index_select_out_cuda_impl", [&] () {
|
||||||
auto indicesInfo = tensorInfoLegacyIfScalar(cuda::detail::getTensorInfo<int64_t, uint64_t>(index));
|
auto indicesInfo = tensorInfoLegacyIfScalar(cuda::detail::getTensorInfo<index_t, uint64_t>(index));
|
||||||
indicesInfo.collapseDims();
|
indicesInfo.collapseDims();
|
||||||
|
|
||||||
LARGE_INDEX(scalar_t, uint64_t, -1, -1, -1, true);
|
LARGE_INDEX(scalar_t, index_t, uint64_t, -1, -1, -1, true);
|
||||||
|
});
|
||||||
}
|
}
|
||||||
#undef SMALL_INDEX
|
#undef SMALL_INDEX
|
||||||
#undef LARGE_INDEX
|
#undef LARGE_INDEX
|
||||||
|
|
|
||||||
|
|
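The dispatch above is what lets the CUDA index_select path consume either index dtype. A minimal Python sketch of the user-visible effect, assuming a CUDA build of this branch (tensor values are arbitrary):

import torch

src = torch.randn(8, 4, device="cuda")
idx32 = torch.tensor([0, 3, 3, 7], dtype=torch.int, device="cuda")  # int32 indices
idx64 = idx32.long()                                                # same indices as int64

# Both calls go through the kernels patched above and should agree exactly.
assert torch.equal(src.index_select(0, idx32), src.index_select(0, idx64))
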
@@ -17,7 +17,7 @@ static bool Fused8BitRowwiseEmbeddingLookupIdx_int32_t_float_float__avx2_fma(
     const int64_t data_size,
     const float* input,
     const int* indices,
-    const int64_t* offsets,
+    const int* offsets,
     const float* weights,
     bool normalize_by_lengths,
     float* out) {
@@ -401,7 +401,7 @@ bool Fused8BitRowwiseEmbeddingLookupIdx_int32_t_float_float_false__avx2_fma(
     const int64_t data_size,
     const float* input,
     const int* indices,
-    const int64_t* offsets,
+    const int* offsets,
     const float* weights,
     bool normalize_by_lengths,
     float* out) {
@@ -425,7 +425,7 @@ bool Fused8BitRowwiseEmbeddingLookupIdx_int32_t_float_float_true__avx2_fma(
     const int64_t data_size,
     const float* input,
     const int* indices,
-    const int64_t* offsets,
+    const int* offsets,
     const float* weights,
     bool normalize_by_lengths,
     float* out) {
@@ -883,7 +883,7 @@ static bool Fused8BitRowwiseEmbeddingLookupIdx_int32_t_half_float__avx2_fma(
     const int64_t data_size,
     const at::Half* input,
     const int* indices,
-    const int64_t* offsets,
+    const int* offsets,
     const float* weights,
     bool normalize_by_lengths,
     float* out) {
@@ -1387,7 +1387,7 @@ bool Fused8BitRowwiseEmbeddingLookupIdx_int32_t_half_float_false__avx2_fma(
     const int64_t data_size,
     const at::Half* input,
     const int* indices,
-    const int64_t* offsets,
+    const int* offsets,
     const float* weights,
     bool normalize_by_lengths,
     float* out) {
@@ -1410,7 +1410,7 @@ bool Fused8BitRowwiseEmbeddingLookupIdx_int32_t_half_float_true__avx2_fma(
     const int64_t data_size,
     const at::Half* input,
     const int* indices,
-    const int64_t* offsets,
+    const int* offsets,
     const float* weights,
     bool normalize_by_lengths,
     float* out) {
@@ -1987,7 +1987,7 @@ static bool Fused8BitRowwiseEmbeddingLookupIdx_int32_t_uint8_t_float__avx2_fma(
     const int64_t data_size,
     const uint8_t* input,
     const int* indices,
-    const int64_t* offsets,
+    const int* offsets,
     const float* weights,
     bool normalize_by_lengths,
     float* out) {
@@ -2514,7 +2514,7 @@ bool Fused8BitRowwiseEmbeddingLookupIdx_int32_t_uint8_t_float_false__avx2_fma(
     const int64_t data_size,
     const uint8_t* input,
     const int* indices,
-    const int64_t* offsets,
+    const int* offsets,
     const float* weights,
     bool normalize_by_lengths,
     float* out) {
@@ -2538,7 +2538,7 @@ bool Fused8BitRowwiseEmbeddingLookupIdx_int32_t_uint8_t_float_true__avx2_fma(
     const int64_t data_size,
     const uint8_t* input,
     const int* indices,
-    const int64_t* offsets,
+    const int* offsets,
     const float* weights,
     bool normalize_by_lengths,
     float* out) {

@@ -23,7 +23,7 @@ static bool EmbeddingLookupGenericSlowIdx(
     const int64_t data_size,
     const InType* input,
     const IndexType* indices,
-    const int64_t* offsets,
+    const IndexType* offsets,
     const float* weights, // optional, can be null for sum reducer
     const float* scale_bias, // optional scale & bias params for uint8 input
     bool normalize_by_lengths,
@@ -85,7 +85,7 @@ static bool EmbeddingLookupGenericSlowIdx(
       const int64_t data_size, \
       const InType* input, \
       const IndexType* indices, \
-      const int64_t* offsets, \
+      const IndexType* offsets, \
       const float* weights, \
       const float* scale_bias, \
       bool normalize_by_lengths, \
@@ -118,7 +118,7 @@ static bool EmbeddingLookupGenericSlowIdx(
       const int64_t data_size, \
       const InType* input, \
       const IndexType* indices, \
-      const int64_t* offsets, \
+      const IndexType* offsets, \
       const float* weights, \
       const float* scale_bias, \
       bool normalize_by_lengths, \
@@ -163,7 +163,7 @@ static bool EmbeddingLookupGenericSlowIdx(
      const int64_t data_size, \
      const InType* input, \
      const IndexType* indices, \
-     const int64_t* offsets, \
+     const IndexType* offsets, \
      const float* weights, \
      const float* scale_bias, \
      bool normalize_by_lengths, \

@@ -48,7 +48,7 @@ void EmbeddingLookupIdx(
     const std::int64_t data_size,
     const InType* input,
     const IndexType* indices,
-    const int64_t* offsets,
+    const IndexType* offsets,
     const float* weights, // optional, can be null for non-weighted sum
     const float* scale_bias, // optional scale & bias params for uint8 input
     bool normalize_by_lengths,

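These lookup kernels read bags out of a flat indices array via offsets, and after this change indices and offsets share the same IndexType. A small illustration of that layout in plain Python, using the same values as the tests further down (variable names here are illustrative, not the kernel's API):

# Bag b covers indices[offsets[b] : offsets[b + 1]]; the last bag runs to the end.
indices = [3, 1, 1, 1, 4, 0]
offsets = [0, 0, 3, 3, 6]   # bag 0 and bag 2 are empty

bags = []
for b, start in enumerate(offsets):
    end = offsets[b + 1] if b + 1 < len(offsets) else len(indices)
    bags.append(indices[start:end])
print(bags)  # [[], [3, 1, 1], [], [1, 4, 0], []]
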
@@ -17,7 +17,7 @@ static bool EmbeddingLookupIdx_int32_t_float_float__avx2_fma(
     const int64_t data_size,
     const float* input,
     const int* indices,
-    const int64_t* offsets,
+    const int* offsets,
     const float* weights,
     const float* scale_bias,
     bool normalize_by_lengths,
@@ -402,7 +402,7 @@ bool EmbeddingLookupIdx_int32_t_float_float_false__avx2_fma(
     const int64_t data_size,
     const float* input,
     const int* indices,
-    const int64_t* offsets,
+    const int* offsets,
     const float* weights,
     const float* scale_bias,
     bool normalize_by_lengths,
@@ -427,7 +427,7 @@ bool EmbeddingLookupIdx_int32_t_float_float_true__avx2_fma(
     const int64_t data_size,
     const float* input,
     const int* indices,
-    const int64_t* offsets,
+    const int* offsets,
     const float* weights,
     const float* scale_bias,
     bool normalize_by_lengths,
@@ -891,7 +891,7 @@ static bool EmbeddingLookupIdx_int32_t_half_float__avx2_fma(
     const int64_t data_size,
     const at::Half* input,
     const int* indices,
-    const int64_t* offsets,
+    const int* offsets,
     const float* weights,
     const float* scale_bias,
     bool normalize_by_lengths,
@@ -1396,7 +1396,7 @@ bool EmbeddingLookupIdx_int32_t_half_float_false__avx2_fma(
     const int64_t data_size,
     const at::Half* input,
     const int* indices,
-    const int64_t* offsets,
+    const int* offsets,
     const float* weights,
     const float* scale_bias,
     bool normalize_by_lengths,
@@ -1421,7 +1421,7 @@ bool EmbeddingLookupIdx_int32_t_half_float_true__avx2_fma(
     const int64_t data_size,
     const at::Half* input,
     const int* indices,
-    const int64_t* offsets,
+    const int* offsets,
     const float* weights,
     const float* scale_bias,
     bool normalize_by_lengths,
@@ -2005,7 +2005,7 @@ static bool EmbeddingLookupIdx_int32_t_uint8_t_float__avx2_fma(
     const int64_t data_size,
     const uint8_t* input,
     const int* indices,
-    const int64_t* offsets,
+    const int* offsets,
     const float* weights,
     const float* scale_bias,
     bool normalize_by_lengths,
@@ -2523,7 +2523,7 @@ bool EmbeddingLookupIdx_int32_t_uint8_t_float_false__avx2_fma(
     const int64_t data_size,
     const uint8_t* input,
     const int* indices,
-    const int64_t* offsets,
+    const int* offsets,
     const float* weights,
     const float* scale_bias,
     bool normalize_by_lengths,
@@ -2548,7 +2548,7 @@ bool EmbeddingLookupIdx_int32_t_uint8_t_float_true__avx2_fma(
     const int64_t data_size,
     const uint8_t* input,
     const int* indices,
-    const int64_t* offsets,
+    const int* offsets,
     const float* weights,
     const float* scale_bias,
     bool normalize_by_lengths,

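Each of these generated int32_t AVX2 variants now takes int offsets alongside int indices, so the two buffers have the same element width. A quick sanity check of the size difference, as a sketch using only stock PyTorch:

import torch

offsets64 = torch.zeros(1_000_000, dtype=torch.long)
offsets32 = offsets64.to(torch.int)

print(offsets64.numel() * offsets64.element_size())  # 8000000 bytes
print(offsets32.numel() * offsets32.element_size())  # 4000000 bytes
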
@@ -22,7 +22,7 @@ static bool Fused8BitRowwiseEmbeddingLookupGenericSlowIdx(
     const int64_t data_size,
     const InType* input,
     const IndexType* indices,
-    const int64_t* offsets,
+    const IndexType* offsets,
     const float* weights, // optional, can be null for sum reducer
     bool normalize_by_lengths,
     OutType* out) {
@@ -88,7 +88,7 @@ static bool Fused8BitRowwiseEmbeddingLookupGenericSlowIdx(
       const int64_t data_size, \
       const uint8_t* input, \
       const IndexType* indices, \
-      const int64_t* offsets, \
+      const IndexType* offsets, \
       const float* weights, \
       bool normalize_by_lengths, \
       OutType* out) { \
@@ -118,7 +118,7 @@ static bool Fused8BitRowwiseEmbeddingLookupGenericSlowIdx(
       const int64_t data_size, \
       const uint8_t* input, \
       const IndexType* indices, \
-      const int64_t* offsets, \
+      const IndexType* offsets, \
       const float* weights, \
       bool normalize_by_lengths, \
       OutType* out) { \
@@ -160,7 +160,7 @@ static bool Fused8BitRowwiseEmbeddingLookupGenericSlowIdx(
      const int64_t data_size, \
      const uint8_t* input, \
      const IndexType* indices, \
-     const int64_t* offsets, \
+     const IndexType* offsets, \
      const float* weights, \
      bool normalize_by_lengths, \
      OutType* out) { \

@@ -50,7 +50,7 @@ void Fused8BitRowwiseEmbeddingLookupIdx(
     const std::int64_t data_size,
     const InType* input,
     const IndexType* indices,
-    const int64_t* offsets,
+    const IndexType* offsets,
     const float* weights, // optional, can be null for non-weighted sum
     bool normalize_by_lengths,
     OutType* out);

@@ -450,7 +450,7 @@ for o in options:
     args.append("    const " + InType + "* input,")
     args.append("    const " + IndexType + "* indices,")
     if opts.use_offsets:
-        args.append("    const int64_t* offsets,")
+        args.append("    const " + IndexType + "* offsets,")
     else:
         args.append("    const int* lengths,")
     args.append("    const float* weights,")

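With the generator change, the offsets parameter is emitted with the same type string as the indices parameter instead of a hard-coded int64_t*. A hedged sketch of what those lines produce for the int32 specialization (the surrounding generator state is simplified away, and the concrete type strings are assumptions based on the generated files above):

IndexType = "int"    # the int32_t specialization
InType = "float"

args = []
args.append("    const " + InType + "* input,")
args.append("    const " + IndexType + "* indices,")
args.append("    const " + IndexType + "* offsets,")  # previously always int64_t*
print("\n".join(args))
#     const float* input,
#     const int* indices,
#     const int* offsets,
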
test/test_nn.py
@@ -11610,25 +11610,27 @@ class TestNNDeviceType(NNTestCase):
         self.assertRaises(RuntimeError, lambda: F.adaptive_max_pool2d(t, []))
         self.assertRaises(RuntimeError, lambda: F.adaptive_max_pool3d(t, []))

-    def test_embedding_bag_empty_input(self, device):
+    @dtypes(torch.int, torch.long)
+    def test_embedding_bag_empty_input(self, device, dtype):
         m = 4
         n = 3
-        x = torch.tensor([], device=device, dtype=torch.long)
+        x = torch.tensor([], device=device, dtype=dtype)
         for sparse in [True, False]:
             Embed = torch.nn.EmbeddingBag(m, n, sparse=sparse)
             Embed.to(device)

-            output = Embed(input=x, offsets=torch.tensor([0], device=device, dtype=torch.long))
+            output = Embed(input=x, offsets=torch.tensor([0], device=device, dtype=dtype))
             self.assertEqual(output, torch.zeros_like(output))

-            output = Embed(input=x, offsets=torch.tensor([0, 0], device=device, dtype=torch.long))
+            output = Embed(input=x, offsets=torch.tensor([0, 0], device=device, dtype=dtype))
             self.assertEqual(output, torch.zeros_like(output))

-    def test_EmbeddingBag_per_sample_weights_failures(self, device):
+    @dtypes(torch.int, torch.long)
+    def test_EmbeddingBag_per_sample_weights_failures(self, device, dtype):
         # Failure 1: mismatched embeddings / per_sample_weights dtype
         es = nn.EmbeddingBag(5, 2, mode='sum').to(dtype=torch.float, device=device)
-        input = torch.tensor([3, 1, 1, 1, 4, 0], dtype=torch.long, device=device)
-        offsets = torch.tensor([0, 0, 3, 3, 6], dtype=torch.long, device=device)
+        input = torch.tensor([3, 1, 1, 1, 4, 0], dtype=dtype, device=device)
+        offsets = torch.tensor([0, 0, 3, 3, 6], dtype=dtype, device=device)
         per_sample_weights = torch.randn_like(input, dtype=torch.double, device=device)
         if device == 'cpu':
             with self.assertRaisesRegex(RuntimeError, 'have the same type as'):

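Outside the test harness, the same behavior can be exercised directly. A minimal sketch assuming this branch, where int32 indices and offsets are accepted together:

import torch

m, n = 4, 3
bag = torch.nn.EmbeddingBag(m, n)

x = torch.tensor([], dtype=torch.int)
offsets = torch.tensor([0], dtype=torch.int)
out = bag(input=x, offsets=offsets)
print(out)        # a single all-zero bag
print(out.shape)  # torch.Size([1, 3])
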
@@ -11638,14 +11640,14 @@ class TestNNDeviceType(NNTestCase):
                 es(input, offsets, per_sample_weights)

         # Failure 2.1: input/per_sample_weights have different sizes (1d input)
-        input = torch.tensor([3, 1, 1, 1, 4, 0], dtype=torch.long, device=device)
-        offsets = torch.tensor([0, 0, 3, 3, 6], dtype=torch.long, device=device)
+        input = torch.tensor([3, 1, 1, 1, 4, 0], dtype=dtype, device=device)
+        offsets = torch.tensor([0, 0, 3, 3, 6], dtype=dtype, device=device)
         per_sample_weights = torch.randn(5, dtype=torch.float, device=device)
         with self.assertRaisesRegex(ValueError, 'same shape as the input'):
             es(input, offsets, per_sample_weights)

         # Failure 2.2: input/per_sample_weights have different sizes (2d input)
-        input = torch.randint(5, (7, 3), dtype=torch.long, device=device)
+        input = torch.randint(5, (7, 3), dtype=dtype, device=device)
         offsets = None
         per_sample_weights = torch.randn(7 * 3, dtype=torch.float, device=device)
         with self.assertRaisesRegex(ValueError, 'same shape as the input'):
@@ -11655,7 +11657,7 @@ class TestNNDeviceType(NNTestCase):
         for unsupported_mode in ('max', 'mean'):
             es = nn.EmbeddingBag(5, 2, mode=unsupported_mode).to(
                 dtype=torch.float, device=device)
-            input = torch.randint(5, (7, 3), dtype=torch.long, device=device)
+            input = torch.randint(5, (7, 3), dtype=dtype, device=device)
             offsets = None
             per_sample_weights = torch.randn(7, 3, dtype=torch.float, device=device)
             with self.assertRaisesRegex(NotImplementedError,

@@ -11673,7 +11675,8 @@ class TestNNDeviceType(NNTestCase):
         assert input.numel() == per_sample_weights.numel()

         bags = []
-        embeddings = weight.index_select(0, input) * per_sample_weights.unsqueeze(1)
+        long_input = input.to(torch.long)
+        embeddings = weight.index_select(0, long_input) * per_sample_weights.unsqueeze(1)
         if include_last_offset:
             for index in range(len(offsets) - 1):
                 offset = offsets[index]
@@ -11698,7 +11701,7 @@ class TestNNDeviceType(NNTestCase):
                 if index + 1 < len(offsets):
                     next_offset = offsets[index + 1]
                 else:
                     next_offset = len(long_input)
                 length = next_offset - offset
                 if length == 0:
                     bags.append(
@@ -11716,16 +11719,18 @@ class TestNNDeviceType(NNTestCase):
                 bags.append(embeddings.narrow(0, offset, length).max(0)[0])
         return torch.stack(bags)

-    def test_EmbeddingBag_empty_per_sample_weights_and_offsets(self, device):
+    @dtypesIfCUDA(*itertools.product((torch.int, torch.long), (torch.float, torch.double, torch.half)))
+    @dtypes(*itertools.product((torch.int, torch.long), (torch.float, torch.double)))
+    def test_EmbeddingBag_empty_per_sample_weights_and_offsets(self, device, dtypes):
         # Test empty input and per sample weight, and backward pass. There was a CUDA
         # invalid configuration bug (more context in #46572)
-        def test_per_sample_weights(mode, dtype, trainable_scale):
-            es = nn.EmbeddingBag(5, 2, mode=mode).to(dtype=dtype, device=device)
+        def test_per_sample_weights(mode, trainable_scale):
+            es = nn.EmbeddingBag(5, 2, mode=mode).to(dtype=dtypes[1], device=device)
             es.weight.data.copy_(
-                torch.arange(1, 11, device=device, dtype=dtype).view_as(es.weight))
-            input = torch.tensor([], device=device, dtype=torch.long)
-            offsets = torch.tensor([0, 0, 0, 0, 0], device=device, dtype=torch.long)
-            per_sample_weights = torch.randn_like(input, dtype=dtype) \
+                torch.arange(1, 11, device=device, dtype=dtypes[1]).view_as(es.weight))
+            input = torch.tensor([], device=device, dtype=dtypes[0])
+            offsets = torch.tensor([0, 0, 0, 0, 0], device=device, dtype=dtypes[0])
+            per_sample_weights = torch.randn_like(input, dtype=dtypes[1]) \
                 .requires_grad_(trainable_scale)
             ref_per_sample_weights = \
                 per_sample_weights.detach().requires_grad_(trainable_scale)
@@ -11734,7 +11739,7 @@ class TestNNDeviceType(NNTestCase):
             expected = self._embedding_bag_reference_impl(
                 input, reference_weights, offsets, mode, ref_per_sample_weights)
             result = es(input, offsets, per_sample_weights)
-            self.assertEqual(result, expected, atol=dtype2prec_DONTUSE[dtype], rtol=0)
+            self.assertEqual(result, expected, atol=dtype2prec_DONTUSE[dtypes[1]], rtol=0)

             grad = torch.randn_like(expected)
             result.backward(grad)
@@ -11742,29 +11747,27 @@ class TestNNDeviceType(NNTestCase):
             # simply be a zero tensor
             ref_weights_grad = torch.zeros_like(es.weight)
             self.assertEqual(es.weight.grad, ref_weights_grad,
-                             atol=dtype2prec_DONTUSE[dtype], rtol=0)
+                             atol=dtype2prec_DONTUSE[dtypes[1]], rtol=0)
             if trainable_scale:
                 ref_per_sample_weights_grad = torch.empty_like(per_sample_weights)
                 self.assertEqual(per_sample_weights.grad, ref_per_sample_weights_grad,
-                                 atol=dtype2prec_DONTUSE[dtype], rtol=0)
+                                 atol=dtype2prec_DONTUSE[dtypes[1]], rtol=0)

-        if device == 'cuda':
-            dtypes = (torch.float, torch.double, torch.half)
-        else:
-            dtypes = (torch.float, torch.double)
         modes = ('sum',)
         trainable_scale = (True, False)
-        for dtype, mode, trainable in itertools.product(dtypes, modes, trainable_scale):
-            test_per_sample_weights(mode, dtype, trainable)
+        for mode, trainable in itertools.product(modes, trainable_scale):
+            test_per_sample_weights(mode, trainable)

-    def test_EmbeddingBag_per_sample_weights_and_offsets(self, device):
-        def test_per_sample_weights(mode, dtype, trainable_scale):
-            es = nn.EmbeddingBag(5, 2, mode=mode).to(dtype=dtype, device=device)
+    @dtypesIfCUDA(*itertools.product((torch.int, torch.long), (torch.float, torch.double, torch.half)))
+    @dtypes(*itertools.product((torch.int, torch.long), (torch.float, torch.double)))
+    def test_EmbeddingBag_per_sample_weights_and_offsets(self, device, dtypes):
+        def test_per_sample_weights(mode, trainable_scale):
+            es = nn.EmbeddingBag(5, 2, mode=mode).to(dtype=dtypes[1], device=device)
             es.weight.data.copy_(
-                torch.arange(1, 11, device=device, dtype=dtype).view_as(es.weight))
-            input = torch.tensor([3, 1, 1, 1, 4, 0], device=device, dtype=torch.long)
-            offsets = torch.tensor([0, 0, 3, 3, 6], device=device, dtype=torch.long)
-            per_sample_weights = torch.randn_like(input, dtype=dtype) \
+                torch.arange(1, 11, device=device, dtype=dtypes[1]).view_as(es.weight))
+            input = torch.tensor([3, 1, 1, 1, 4, 0], device=device, dtype=dtypes[0])
+            offsets = torch.tensor([0, 0, 3, 3, 6], device=device, dtype=dtypes[0])
+            per_sample_weights = torch.randn_like(input, dtype=dtypes[1]) \
                 .requires_grad_(trainable_scale)
             ref_per_sample_weights = \
                 per_sample_weights.detach().requires_grad_(trainable_scale)

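The decorator pairs above feed each test a (index dtype, weight dtype) tuple, so dtypes[0] selects the indices/offsets type and dtypes[1] the embedding weight type. The same pairing can be spelled out by hand; a sketch assuming this branch:

import itertools
import torch

index_dtypes = (torch.int, torch.long)
weight_dtypes = (torch.float, torch.double)

for idx_dtype, w_dtype in itertools.product(index_dtypes, weight_dtypes):
    es = torch.nn.EmbeddingBag(5, 2, mode="sum").to(dtype=w_dtype)
    inp = torch.tensor([3, 1, 1, 1, 4, 0], dtype=idx_dtype)
    offsets = torch.tensor([0, 0, 3, 3, 6], dtype=idx_dtype)      # same dtype as inp
    per_sample_weights = torch.randn_like(inp, dtype=w_dtype)
    print(es(inp, offsets, per_sample_weights).shape)             # torch.Size([5, 2])
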
@@ -11773,39 +11776,37 @@ class TestNNDeviceType(NNTestCase):
             expected = self._embedding_bag_reference_impl(
                 input, reference_weights, offsets, mode, ref_per_sample_weights)
             result = es(input, offsets, per_sample_weights)
-            self.assertEqual(result, expected, atol=dtype2prec_DONTUSE[dtype], rtol=0)
+            self.assertEqual(result, expected, atol=dtype2prec_DONTUSE[dtypes[1]], rtol=0)

             grad = torch.randn_like(expected)
             result.backward(grad)
             expected.backward(grad)
             self.assertEqual(es.weight.grad, reference_weights.grad,
-                             atol=dtype2prec_DONTUSE[dtype], rtol=0)
+                             atol=dtype2prec_DONTUSE[dtypes[1]], rtol=0)
             if trainable_scale:
                 self.assertEqual(per_sample_weights.grad, ref_per_sample_weights.grad,
-                                 atol=dtype2prec_DONTUSE[dtype], rtol=0)
+                                 atol=dtype2prec_DONTUSE[dtypes[1]], rtol=0)

-        if device == 'cuda':
-            dtypes = (torch.float, torch.double, torch.half)
-        else:
-            dtypes = (torch.float, torch.double)
         modes = ('sum',)
         trainable_scale = (True, False)
-        for dtype, mode, trainable in itertools.product(dtypes, modes, trainable_scale):
-            test_per_sample_weights(mode, dtype, trainable)
+        for mode, trainable in itertools.product(modes, trainable_scale):
+            test_per_sample_weights(mode, trainable)

-    def test_EmbeddingBag_per_sample_weights_and_new_offsets(self, device):
-        def test_per_sample_weights_new_offsets(mode, dtype, trainable_scale, include_last_offset, has_weight=True):
-            es = nn.EmbeddingBag(5, 2, mode=mode, include_last_offset=include_last_offset).to(dtype=dtype, device=device)
+    @dtypesIfCUDA(*itertools.product((torch.int, torch.long), (torch.float, torch.double, torch.half)))
+    @dtypes(*itertools.product((torch.int, torch.long), (torch.float, torch.double)))
+    def test_EmbeddingBag_per_sample_weights_and_new_offsets(self, device, dtypes):
+        def test_per_sample_weights_new_offsets(mode, trainable_scale, include_last_offset, has_weight=True):
+            es = nn.EmbeddingBag(5, 2, mode=mode, include_last_offset=include_last_offset).to(dtype=dtypes[1], device=device)
             es.weight.data.copy_(
-                torch.arange(1, 11, device=device, dtype=dtype).view_as(es.weight))
-            input = torch.tensor([3, 1, 1, 1, 4, 0], device=device, dtype=torch.long)
-            offsets = torch.tensor([0, 0, 3, 3, 6], device=device, dtype=torch.long)
+                torch.arange(1, 11, device=device, dtype=dtypes[1]).view_as(es.weight))
+            input = torch.tensor([3, 1, 1, 1, 4, 0], device=device, dtype=dtypes[0])
+            offsets = torch.tensor([0, 0, 3, 3, 6], device=device, dtype=dtypes[0])

             if include_last_offset:
-                offsets = torch.cat((offsets, torch.tensor([input.size(0)], device=device, dtype=torch.long)), 0)
+                offsets = torch.cat((offsets, torch.tensor([input.size(0)], device=device, dtype=dtypes[0])), 0)

             if has_weight:
-                per_sample_weights = torch.randn_like(input, device=device, dtype=dtype) \
+                per_sample_weights = torch.randn_like(input, device=device, dtype=dtypes[1]) \
                     .requires_grad_(trainable_scale)
                 ref_per_sample_weights = \
                     per_sample_weights.detach().requires_grad_(trainable_scale)

@@ -11818,51 +11819,48 @@ class TestNNDeviceType(NNTestCase):
             expected = self._embedding_bag_reference_impl(
                 input, reference_weights, offsets, mode, ref_per_sample_weights, include_last_offset)
             result = es(input, offsets, per_sample_weights)
-            self.assertEqual(result, expected, atol=dtype2prec_DONTUSE[dtype], rtol=0)
+            self.assertEqual(result, expected, atol=dtype2prec_DONTUSE[dtypes[1]], rtol=0)

             grad = torch.randn_like(expected)
             result.backward(grad)
             expected.backward(grad)
             self.assertEqual(es.weight.grad, reference_weights.grad,
-                             atol=dtype2prec_DONTUSE[dtype], rtol=0)
+                             atol=dtype2prec_DONTUSE[dtypes[1]], rtol=0)
             if has_weight and trainable_scale:
                 self.assertEqual(per_sample_weights.grad, ref_per_sample_weights.grad,
-                                 atol=dtype2prec_DONTUSE[dtype], rtol=0)
+                                 atol=dtype2prec_DONTUSE[dtypes[1]], rtol=0)

-        if device == 'cuda':
-            dtypes = (torch.float, torch.double, torch.half)
-        else:
-            dtypes = (torch.float, torch.double)
         trainable_scale = (True, False)
         include_last_offset = (True, False)
         modes = (('sum', False), ('sum', True), ('max', False), ('mean', False))
-        for dtype, (mode, has_weight), trainable, include_last_offset in itertools.product(
-            dtypes, modes, trainable_scale, include_last_offset
+        for (mode, has_weight), trainable, include_last_offset in itertools.product(
+            modes, trainable_scale, include_last_offset
         ):
             test_per_sample_weights_new_offsets(
-                mode, dtype, trainable, include_last_offset, has_weight
+                mode, trainable, include_last_offset, has_weight
             )

     def _test_EmbeddingBag_vs_Embedding(self, N, D, B, L, max_norm=None,
                                         mode='mean',
                                         device='cpu',
-                                        dtype=torch.float,
+                                        wdtype=torch.float,
+                                        dtype=torch.long,
                                         test_per_sample_weights=False,
                                         trainable_per_sample_weights=False,
                                         sparse=False,
                                         test_backward=True,
                                         backward_prec=None):
-        es = nn.EmbeddingBag(N, D, mode=mode, sparse=sparse, max_norm=max_norm).to(device, dtype)
-        e = nn.Embedding(N, D, max_norm=max_norm).to(device, dtype)
+        es = nn.EmbeddingBag(N, D, mode=mode, sparse=sparse, max_norm=max_norm).to(device, wdtype)
+        e = nn.Embedding(N, D, max_norm=max_norm).to(device, wdtype)
         e.weight.data.copy_(es.weight)
-        input = torch.randint(N, (B, L), device=device, dtype=torch.long)
-        offsets = torch.arange(0, B, device=device, dtype=torch.long).mul_(L)
-        grad_output = torch.rand(B, D, device=device, dtype=dtype)
+        input = torch.randint(N, (B, L), device=device, dtype=dtype)
+        offsets = torch.arange(0, B, device=device, dtype=dtype).mul_(L)
+        grad_output = torch.rand(B, D, device=device, dtype=wdtype)

         if test_per_sample_weights:
             # To prevent large gradients, weights should sum to 1 for each bag
             per_sample_weights = \
-                torch.randn(B, L, device=device, dtype=dtype).softmax(dim=-1)
+                torch.randn(B, L, device=device, dtype=wdtype).softmax(dim=-1)
             per_sample_weights_reference = \
                 per_sample_weights.clone().requires_grad_(trainable_per_sample_weights)
             per_sample_weights.requires_grad_(trainable_per_sample_weights)

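The helper above compares EmbeddingBag against an explicit Embedding plus a reduction. A stripped-down sketch of that comparison with int32 indices and offsets, assuming this branch (the sizes are arbitrary):

import torch

N, D, B, L = 5, 3, 2, 4
es = torch.nn.EmbeddingBag(N, D, mode="sum")
e = torch.nn.Embedding(N, D)
e.weight.data.copy_(es.weight)

inp = torch.randint(N, (B, L), dtype=torch.int)
offsets = torch.arange(0, B, dtype=torch.int).mul_(L)

out_bag = es(inp.view(-1), offsets)   # (B, D), bags of constant length L
out_ref = e(inp.long()).sum(1)        # the same reduction done by hand
print(torch.allclose(out_bag, out_ref))
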
||||||
def _test_EmbeddingBag_vs_Embedding(self, N, D, B, L, max_norm=None,
|
def _test_EmbeddingBag_vs_Embedding(self, N, D, B, L, max_norm=None,
|
||||||
mode='mean',
|
mode='mean',
|
||||||
device='cpu',
|
device='cpu',
|
||||||
dtype=torch.float,
|
wdtype=torch.float,
|
||||||
|
dtype=torch.long,
|
||||||
test_per_sample_weights=False,
|
test_per_sample_weights=False,
|
||||||
trainable_per_sample_weights=False,
|
trainable_per_sample_weights=False,
|
||||||
sparse=False,
|
sparse=False,
|
||||||
test_backward=True,
|
test_backward=True,
|
||||||
backward_prec=None):
|
backward_prec=None):
|
||||||
es = nn.EmbeddingBag(N, D, mode=mode, sparse=sparse, max_norm=max_norm).to(device, dtype)
|
es = nn.EmbeddingBag(N, D, mode=mode, sparse=sparse, max_norm=max_norm).to(device, wdtype)
|
||||||
e = nn.Embedding(N, D, max_norm=max_norm).to(device, dtype)
|
e = nn.Embedding(N, D, max_norm=max_norm).to(device, wdtype)
|
||||||
e.weight.data.copy_(es.weight)
|
e.weight.data.copy_(es.weight)
|
||||||
input = torch.randint(N, (B, L), device=device, dtype=torch.long)
|
input = torch.randint(N, (B, L), device=device, dtype=dtype)
|
||||||
offsets = torch.arange(0, B, device=device, dtype=torch.long).mul_(L)
|
offsets = torch.arange(0, B, device=device, dtype=dtype).mul_(L)
|
||||||
grad_output = torch.rand(B, D, device=device, dtype=dtype)
|
grad_output = torch.rand(B, D, device=device, dtype=wdtype)
|
||||||
|
|
||||||
if test_per_sample_weights:
|
if test_per_sample_weights:
|
||||||
# To prevent large gradients, weights should sum to 1 for each bag
|
# To prevent large gradients, weights should sum to 1 for each bag
|
||||||
per_sample_weights = \
|
per_sample_weights = \
|
||||||
torch.randn(B, L, device=device, dtype=dtype).softmax(dim=-1)
|
torch.randn(B, L, device=device, dtype=wdtype).softmax(dim=-1)
|
||||||
per_sample_weights_reference = \
|
per_sample_weights_reference = \
|
||||||
per_sample_weights.clone().requires_grad_(trainable_per_sample_weights)
|
per_sample_weights.clone().requires_grad_(trainable_per_sample_weights)
|
||||||
per_sample_weights.requires_grad_(trainable_per_sample_weights)
|
per_sample_weights.requires_grad_(trainable_per_sample_weights)
|
||||||
|
|
@ -11884,7 +11882,7 @@ class TestNNDeviceType(NNTestCase):
|
||||||
assert not test_per_sample_weights
|
assert not test_per_sample_weights
|
||||||
ref_output = e(input).max(1)[0]
|
ref_output = e(input).max(1)[0]
|
||||||
|
|
||||||
self.assertEqual(output, ref_output, atol=dtype2prec_DONTUSE[dtype], rtol=0)
|
self.assertEqual(output, ref_output, atol=dtype2prec_DONTUSE[wdtype], rtol=0)
|
||||||
|
|
||||||
if not test_backward:
|
if not test_backward:
|
||||||
return
|
return
|
||||||
|
|
@ -11897,7 +11895,7 @@ class TestNNDeviceType(NNTestCase):
|
||||||
|
|
||||||
# We have more floating point error here because we are dealing with larger numbers
|
# We have more floating point error here because we are dealing with larger numbers
|
||||||
if backward_prec is None:
|
if backward_prec is None:
|
||||||
needed_prec = dtype2prec_DONTUSE[dtype] * 3
|
needed_prec = dtype2prec_DONTUSE[wdtype] * 3
|
||||||
else:
|
else:
|
||||||
needed_prec = backward_prec
|
needed_prec = backward_prec
|
||||||
|
|
||||||
|
|
@ -11905,13 +11903,15 @@ class TestNNDeviceType(NNTestCase):
|
||||||
|
|
||||||
if test_per_sample_weights and trainable_per_sample_weights:
|
if test_per_sample_weights and trainable_per_sample_weights:
|
||||||
self.assertEqual(per_sample_weights.grad, per_sample_weights_reference.grad,
|
self.assertEqual(per_sample_weights.grad, per_sample_weights_reference.grad,
|
||||||
atol=dtype2prec_DONTUSE[dtype], rtol=0)
|
atol=dtype2prec_DONTUSE[wdtype], rtol=0)
|
||||||
|
|
||||||
@skipCUDAIf(True, "Temporarily disabled. See t54369166")
|
@skipCUDAIf(True, "Temporarily disabled. See t54369166")
|
||||||
def test_EmbeddingBag_per_sample_weights_and_no_offsets(self, device):
|
@dtypesIfCUDA(*itertools.product((torch.int, torch.long), (torch.half, torch.float, torch.double)))
|
||||||
def run_tests(dtype, mode, sparse, trainable_per_sample_weights):
|
@dtypes(*itertools.product((torch.int, torch.long), (torch.float, torch.double)))
|
||||||
|
def test_EmbeddingBag_per_sample_weights_and_no_offsets(self, device, dtypes):
|
||||||
|
def run_tests(mode, sparse, trainable_per_sample_weights):
|
||||||
kwargs = dict(test_per_sample_weights=True, device=device,
|
kwargs = dict(test_per_sample_weights=True, device=device,
|
||||||
mode=mode, dtype=dtype, sparse=sparse,
|
mode=mode, wdtype=dtypes[1], dtype=dtypes[0], sparse=sparse,
|
||||||
trainable_per_sample_weights=trainable_per_sample_weights)
|
trainable_per_sample_weights=trainable_per_sample_weights)
|
||||||
|
|
||||||
# Simple case
|
# Simple case
|
||||||
|
|
@@ -11926,78 +11926,76 @@ class TestNNDeviceType(NNTestCase):
             # Large embedding_dim
             self._test_EmbeddingBag_vs_Embedding(2, 101, 3, 7, **kwargs)

-        dtypes = (torch.float, torch.double)
         modes = ('sum',)
         sparsity = (True, False)
         trainable_scale = (True, False)
-        for dtype, mode, sparse, trainable_per_sample_weights in \
-                itertools.product(dtypes, modes, sparsity, trainable_scale):
-            run_tests(dtype, mode, sparse, trainable_per_sample_weights)
+        for mode, sparse, trainable_per_sample_weights in \
+                itertools.product(modes, sparsity, trainable_scale):
+            run_tests(mode, sparse, trainable_per_sample_weights)

         # Test CUDA Dense on half precision
         if device == 'cuda':
-            dtypes = (torch.half,)
             modes = ('sum',)
             sparsity = (False,)
             trainable_scale = (True, False)
-            for dtype, mode, sparse, trainable_per_sample_weights in \
-                    itertools.product(dtypes, modes, sparsity, trainable_scale):
-                run_tests(dtype, mode, sparse, trainable_per_sample_weights)
+            for mode, sparse, trainable_per_sample_weights in \
+                    itertools.product(modes, sparsity, trainable_scale):
+                run_tests(mode, sparse, trainable_per_sample_weights)

-    def _test_EmbeddingBag(self, device, mode, sparse, dtype=torch.double, test_backward=True):
+    def _test_EmbeddingBag(self, device, mode, sparse, wdtype=torch.double, dtype=torch.long, test_backward=True):
         # check a known test example
-        es = nn.EmbeddingBag(5, 2, mode=mode, sparse=sparse).to(device, dtype)
-        es.weight.data.copy_(torch.arange(1, 11, device=device, dtype=dtype).view_as(es.weight))
-        input = torch.tensor([3, 1, 1, 1, 4, 0], device=device, dtype=torch.long)
-        offsets = torch.tensor([0, 0, 3, 3, 6], device=device, dtype=torch.long)
+        es = nn.EmbeddingBag(5, 2, mode=mode, sparse=sparse).to(device, wdtype)
+        es.weight.data.copy_(torch.arange(1, 11, device=device, dtype=wdtype).view_as(es.weight))
+        input = torch.tensor([3, 1, 1, 1, 4, 0], device=device, dtype=dtype)
+        offsets = torch.tensor([0, 0, 3, 3, 6], device=device, dtype=dtype)

         grad_output = torch.tensor(
             [1, 2,
-             3, 4], device=device, dtype=dtype).view(2, 2)
+             3, 4], device=device, dtype=wdtype).view(2, 2)
         grad_output_with_empty = torch.tensor(
             [99, 99,
              1, 2,
              99, 99,
              3, 4,
-             99, 99], device=device, dtype=dtype).view(5, 2)
+             99, 99], device=device, dtype=wdtype).view(5, 2)

         if mode == "sum" or mode == "mean":
             denominator = 1 if mode == "sum" else 3
             expected_output = torch.tensor(
                 [[13, 16],
-                 [13, 16]], device=device, dtype=dtype) / denominator
+                 [13, 16]], device=device, dtype=wdtype) / denominator

             expected_output_with_empty = torch.tensor(
                 [[0, 0],
                  [13, 16],
                  [0, 0],
                  [13, 16],
-                 [0, 0]], device=device, dtype=dtype) / denominator
+                 [0, 0]], device=device, dtype=wdtype) / denominator

             expected_grad_weight = torch.tensor(
                 [[3, 4],
                  [5, 8],
                  [0, 0],
                  [1, 2],
-                 [3, 4]], device=device, dtype=dtype) / denominator
+                 [3, 4]], device=device, dtype=wdtype) / denominator
         elif mode == "max":
             expected_output = torch.tensor(
                 [[7, 8],
-                 [9, 10]], device=device, dtype=dtype)
+                 [9, 10]], device=device, dtype=wdtype)

             expected_output_with_empty = torch.tensor(
                 [[0, 0],
                  [7, 8],
                  [0, 0],
                  [9, 10],
-                 [0, 0]], device=device, dtype=dtype)
+                 [0, 0]], device=device, dtype=wdtype)

             expected_grad_weight = torch.tensor(
                 [[0, 0],
                  [0, 0],
                  [0, 0],
                  [1, 2],
-                 [3, 4]], device=device, dtype=dtype)
+                 [3, 4]], device=device, dtype=wdtype)
         output = es(input, offsets)
         output.backward(grad_output_with_empty)

@@ -12005,7 +12003,7 @@ class TestNNDeviceType(NNTestCase):
         if sparse:
             es_weight_grad = es.weight.grad.to_dense()
         self.assertEqual(output, expected_output_with_empty)
-        self.assertEqual(es_weight_grad, expected_grad_weight, atol=dtype2prec_DONTUSE[dtype], rtol=0)
+        self.assertEqual(es_weight_grad, expected_grad_weight, atol=dtype2prec_DONTUSE[wdtype], rtol=0)

         # check same example except as 2D (2 x 3)
         input = input.view(2, -1)
@@ -12017,12 +12015,12 @@ class TestNNDeviceType(NNTestCase):
         if sparse:
             es_weight_grad = es.weight.grad.to_dense()
         self.assertEqual(output, expected_output)
-        self.assertEqual(es_weight_grad, expected_grad_weight, atol=dtype2prec_DONTUSE[dtype], rtol=0)
+        self.assertEqual(es_weight_grad, expected_grad_weight, atol=dtype2prec_DONTUSE[wdtype], rtol=0)

         # test all empty bags
         es.zero_grad()
-        inputs = torch.tensor([], dtype=torch.long, device=device)
-        offsets = torch.tensor([0, 0, 0, 0], device=device)
+        inputs = torch.tensor([], dtype=dtype, device=device)
+        offsets = torch.tensor([0, 0, 0, 0], dtype=dtype, device=device)
         es(inputs, offsets).sum().backward()
         dense_grad = es.weight.grad
         if dense_grad.is_sparse:
@@ -12031,7 +12029,7 @@ class TestNNDeviceType(NNTestCase):

         # now compare EmbeddingBag vs Embedding + Sum/Mean, for constant bag length
         N, D, B, L = random.randint(1, 100), random.randint(1, 100), random.randint(1, 50), random.randint(1, 50)
-        kwargs = dict(mode=mode, sparse=sparse, device=device, dtype=dtype, test_backward=test_backward)
+        kwargs = dict(mode=mode, sparse=sparse, device=device, wdtype=wdtype, dtype=dtype, test_backward=test_backward)
         self._test_EmbeddingBag_vs_Embedding(N, D, B, L, **kwargs)
         for max_norm in (None, 3):
             for p in itertools.product([1, 2], repeat=4):

|
||||||
|
|
||||||
# check that giving illegal input combos raises error
|
# check that giving illegal input combos raises error
|
||||||
es = nn.EmbeddingBag(10, 20, mode=mode, sparse=sparse)
|
es = nn.EmbeddingBag(10, 20, mode=mode, sparse=sparse)
|
||||||
input = torch.ones(3, 4, dtype=torch.long)
|
input = torch.ones(3, 4, dtype=dtype)
|
||||||
offset = torch.arange(0, 3)
|
offset = torch.arange(0, 3, dtype=dtype)
|
||||||
self.assertRaises(ValueError, lambda: es(input, offset))
|
self.assertRaises(ValueError, lambda: es(input, offset))
|
||||||
self.assertRaises(ValueError, lambda: es(input.view(-1)))
|
self.assertRaises(ValueError, lambda: es(input.view(-1)))
|
||||||
offset[0] = 1
|
offset[0] = 1
|
||||||
|
|
@@ -12050,35 +12048,35 @@ class TestNNDeviceType(NNTestCase):
         offset[-1] = 100
         self.assertRaises(RuntimeError, lambda: es(input.view(-1), offset))

-    @dtypesIfCUDA(torch.half, torch.float, torch.double)
-    @dtypes(torch.float, torch.double)
-    def test_embedding_bag_device(self, device, dtype):
-        self._test_EmbeddingBag(device, 'sum', False, dtype)
-        self._test_EmbeddingBag(device, 'mean', False, dtype)
-        self._test_EmbeddingBag(device, 'max', False, dtype)
+    @dtypesIfCUDA(*itertools.product((torch.int, torch.long), (torch.float, torch.double, torch.half)))
+    @dtypes(*itertools.product((torch.int, torch.long), (torch.float, torch.double)))
+    def test_embedding_bag_device(self, device, dtypes):
+        self._test_EmbeddingBag(device, 'sum', False, wdtype=dtypes[1], dtype=dtypes[0])
+        self._test_EmbeddingBag(device, 'mean', False, wdtype=dtypes[1], dtype=dtypes[0])
+        self._test_EmbeddingBag(device, 'max', False, wdtype=dtypes[1], dtype=dtypes[0])

         test_backward = False
         if self.device_type == 'cuda':
             # see 'todo' in test_embedding_bag.
-            test_backward = dtype is not torch.float16
+            test_backward = dtypes[1] is not torch.float16
         elif self.device_type == 'cpu':
             # TODO: figure out why precision on sparse embeddings isn't the
             # same as for dense.
-            test_backward = dtype is not torch.float
+            test_backward = dtypes[1] is not torch.float

-        self._test_EmbeddingBag(device, 'sum', True, dtype, test_backward=test_backward)
-        self._test_EmbeddingBag(device, 'mean', True, dtype, test_backward=test_backward)
+        self._test_EmbeddingBag(device, 'sum', True, wdtype=dtypes[1], dtype=dtypes[0], test_backward=test_backward)
+        self._test_EmbeddingBag(device, 'mean', True, wdtype=dtypes[1], dtype=dtypes[0], test_backward=test_backward)

-    @dtypesIfCUDA(torch.half, torch.float, torch.double)
-    @dtypes(torch.float, torch.double)
-    def test_embedding_bag_non_contiguous_weight(self, device, dtype):
-        weight_tensor = torch.randn(3, 4, dtype=dtype, device=device)
+    @dtypesIfCUDA(*itertools.product((torch.int, torch.long), (torch.float, torch.double, torch.half)))
+    @dtypes(*itertools.product((torch.int, torch.long), (torch.float, torch.double)))
+    def test_embedding_bag_non_contiguous_weight(self, device, dtypes):
+        weight_tensor = torch.randn(3, 4, dtype=dtypes[1], device=device)

         weight_tensor_non_contig = weight_tensor[:, :3]  # This is non-contiguous strided.
         weight_tensor_contig = weight_tensor_non_contig.clone().contiguous()  # Contig-strided.

-        index = torch.tensor([0, 1, 2], device=device)
-        offsets = torch.tensor([0, 2], device=device)
+        index = torch.tensor([0, 1, 2], dtype=dtypes[0], device=device)
+        offsets = torch.tensor([0, 2], dtype=dtypes[0], device=device)
         for mode in ['sum', 'mean', 'max']:
             output_non_contig = F.embedding_bag(
                 input=index,

@@ -12097,9 +12095,10 @@ class TestNNDeviceType(NNTestCase):

     @onlyCUDA
     @skipCUDAIfNotRocm
-    def test_embedding_bag_bfloat16(self, device):
-        self._test_EmbeddingBag(device, 'sum', True, dtype=torch.bfloat16, test_backward=True)
-        self._test_EmbeddingBag(device, 'mean', True, dtype=torch.bfloat16, test_backward=True)
+    @dtypes(torch.int, torch.long)
+    def test_embedding_bag_bfloat16(self, device, dtype):
+        self._test_EmbeddingBag(device, 'sum', True, wdtype=torch.bfloat16, dtype=dtype, test_backward=True)
+        self._test_EmbeddingBag(device, 'mean', True, wdtype=torch.bfloat16, dtype=dtype, test_backward=True)


     @onlyCUDA

@ -1618,16 +1618,18 @@ class AbstractTestCases:
|
||||||
reference[0.0, :, 0.0] = 1
|
reference[0.0, :, 0.0] = 1
|
||||||
|
|
||||||
def test_index_add(self):
|
def test_index_add(self):
|
||||||
|
for device in torch.testing.get_all_device_types():
|
||||||
for dest_contig, src_contig, index_contig in product([True, False], repeat=3):
|
for dest_contig, src_contig, index_contig in product([True, False], repeat=3):
|
||||||
for other_sizes in ((), (4, 5)):
|
for other_sizes in ((), (4, 5)):
|
||||||
|
for dtype in [torch.int, torch.long]:
|
||||||
num_copy, num_dest = 3, 3
|
num_copy, num_dest = 3, 3
|
||||||
dest = torch.randn(num_dest, *other_sizes)
|
dest = torch.randn(num_dest, *other_sizes, device=device)
|
||||||
if not dest_contig:
|
if not dest_contig:
|
||||||
dest = torch.testing.make_non_contiguous(dest)
|
dest = torch.testing.make_non_contiguous(dest)
|
||||||
src = torch.randn(num_copy, *other_sizes)
|
src = torch.randn(num_copy, *other_sizes, device=device)
|
||||||
if not src_contig:
|
if not src_contig:
|
||||||
src = torch.testing.make_non_contiguous(src)
|
src = torch.testing.make_non_contiguous(src)
|
||||||
idx = torch.randperm(num_dest).narrow(0, 0, num_copy)
|
idx = torch.randperm(num_dest, dtype=dtype, device=device).narrow(0, 0, num_copy)
|
||||||
if not index_contig:
|
if not index_contig:
|
||||||
idx = torch.testing.make_non_contiguous(idx)
|
idx = torch.testing.make_non_contiguous(idx)
|
||||||
dest2 = dest.clone()
|
dest2 = dest.clone()
|
||||||
@@ -1642,6 +1644,7 @@ class AbstractTestCases:
    def test_index_add_all_dtypes(self):
        for device in torch.testing.get_all_device_types():
            for dtype in torch.testing.get_all_math_dtypes(device):
+               for idx_dtype in [torch.int, torch.long]:
                    size = [5, 5]
                    if dtype.is_floating_point or dtype.is_complex:
                        tensor = torch.rand(size, dtype=dtype, device=device)
@@ -1657,7 +1660,7 @@ class AbstractTestCases:
                    if device.startswith('cuda') and dtype.is_complex:
                        continue

-                   added = zeros.index_add(0, torch.arange(0, size[0], dtype=torch.long, device=device), tensor)
+                   added = zeros.index_add(0, torch.arange(0, size[0], dtype=idx_dtype, device=device), tensor)
                    self.assertEqual(added, tensor)

    def test_t(self):
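
A short illustrative sketch of the index_add behavior the updated tests cover (example values are hypothetical, not taken from the diff): the index tensor may now be int32 as well as int64.

import torch

zeros = torch.zeros(5, 5)
src = torch.rand(5, 5)
idx = torch.arange(0, 5, dtype=torch.int)  # int32 index tensor
added = zeros.index_add(0, idx, src)       # adds rows of src at positions given by idx
assert torch.equal(added, src)             # with idx = 0..4, this reproduces src exactly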

@@ -12735,9 +12738,10 @@ class TestTorchDeviceType(TestCase):
        self.assertEqual(x, torch.tensor([[0, 2], [0, 5]], dtype=dt, device=device))

    def test_index_select(self, device):
+       for dtype in [torch.int, torch.long]:
            src = torch.randn(3, 4, 5, device=device)
            # Index can be duplicated.
-           idx = torch.tensor([2, 1, 0, 1, 2], dtype=torch.long, device=device)
+           idx = torch.tensor([2, 1, 0, 1, 2], dtype=dtype, device=device)
            dest = torch.index_select(src, 0, idx)
            self.assertEqual(dest.shape, (5, 4, 5))
            for i in range(idx.size(0)):
@@ -12754,13 +12758,13 @@ class TestTorchDeviceType(TestCase):

            # Bool tensor
            src = torch.tensor([False, True, False, False], device=device, dtype=torch.bool)
-           idx = torch.tensor([1], dtype=torch.long, device=device)
+           idx = torch.tensor([1], dtype=dtype, device=device)
            dest = torch.index_select(src, 0, idx)
            self.assertEqual(torch.tensor([True]), dest)

            # Complex Tensor
            src = torch.randn(3, 4, 5, dtype=torch.complex64, device=device)
-           idx = torch.tensor([2, 1, 0, 1, 2], dtype=torch.long, device=device)
+           idx = torch.tensor([2, 1, 0, 1, 2], dtype=dtype, device=device)
            dest = torch.index_select(src, 0, idx)
            self.assertEqual(dest.shape, (5, 4, 5))
            for i in range(idx.size(0)):
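
Likewise for index_select, a minimal sketch of the int32-index path the test now iterates over (values are illustrative only):

import torch

src = torch.randn(3, 4, 5)
idx = torch.tensor([2, 1, 0, 1, 2], dtype=torch.int)  # int32 indices; duplicates are allowed
dest = torch.index_select(src, 0, idx)                # selects rows 2, 1, 0, 1, 2 along dim 0
assert dest.shape == (5, 4, 5)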

@@ -1678,7 +1678,7 @@ Note:

Args:
    dim (int): dimension along which to index
-   index (LongTensor): indices of :attr:`tensor` to select from
+   index (IntTensor or LongTensor): indices of :attr:`tensor` to select from
    tensor (Tensor): the tensor containing values to add

Example::

@@ -3410,7 +3410,7 @@ of :attr:`index`; other dimensions have the same size as in the original tensor.
Args:
    {input}
    dim (int): the dimension in which we index
-   index (LongTensor): the 1-D tensor containing the indices to index
+   index (IntTensor or LongTensor): the 1-D tensor containing the indices to index

Keyword args:
    {out}

@@ -1940,7 +1940,7 @@ def embedding_bag(input, weight, offsets=None, max_norm=None, norm_type=2,
                             " fixed length sequences. However, found "
                             "offsets of type {}".format(type_str))
        offsets = torch.arange(0, input.numel(), input.size(1),
-                              dtype=torch.long, device=input.device)
+                              dtype=input.dtype, device=input.device)

        input = input.reshape(-1)
        if per_sample_weights is not None:
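
The hunk above touches the 2-D fast path of F.embedding_bag: when no offsets are given, the synthesized offsets now follow input.dtype instead of being hard-coded to torch.long, keeping indices and offsets the same type. A hypothetical sketch of that path (names and values are illustrative, not from the diff):

import torch
import torch.nn.functional as F

weight = torch.randn(10, 3)
input_2d = torch.tensor([[1, 2, 4], [5, 4, 3]], dtype=torch.int)  # B=2 bags, fixed length N=3
out = F.embedding_bag(input_2d, weight, mode='mean')              # offsets derived internally with input's dtype
print(out.shape)  # torch.Size([2, 3])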

@@ -34,7 +34,7 @@ class Embedding(Module):
                         initialized from :math:`\mathcal{N}(0, 1)`

    Shape:
-       - Input: :math:`(*)`, LongTensor of arbitrary shape containing the indices to extract
+       - Input: :math:`(*)`, IntTensor or LongTensor of arbitrary shape containing the indices to extract
        - Output: :math:`(*, H)`, where `*` is the input shape and :math:`H=\text{embedding\_dim}`

    .. note::

@@ -246,9 +246,11 @@ class EmbeddingBag(Module):
        weight (Tensor): the learnable weights of the module of shape `(num_embeddings, embedding_dim)`
                         initialized from :math:`\mathcal{N}(0, 1)`.

-   Inputs: :attr:`input` (LongTensor), :attr:`offsets` (LongTensor, optional), and
+   Inputs: :attr:`input` (IntTensor or LongTensor), :attr:`offsets` (IntTensor or LongTensor, optional), and
            :attr:`per_index_weights` (Tensor, optional)

+       - :attr:`input` and :attr:`offsets` have to be of the same type, either int or long
+
        - If :attr:`input` is 2D of shape `(B, N)`,

          it will be treated as ``B`` bags (sequences) each of fixed length ``N``, and
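
A module-level sketch of the documented constraint above (hypothetical example, not part of the diff): input and offsets passed to nn.EmbeddingBag must share a dtype, either torch.int or torch.long.

import torch
import torch.nn as nn

bag = nn.EmbeddingBag(num_embeddings=10, embedding_dim=3, mode='sum')
input = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9], dtype=torch.int)  # int32 indices
offsets = torch.tensor([0, 4], dtype=torch.int)                  # same dtype as input
out = bag(input, offsets)                                        # -> shape (2, 3)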