mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Add range check embedding_bag on input index >= 0 of cuda device (#140791)
Fixes #89362 **Test Result** **Before** ``` >>> import torch >>> input = torch.randint(-5, 1, [1, 2], dtype=torch.int64).cuda() >>> weight = torch.rand([2, 3], dtype=torch.float32).cuda() >>> print(torch.nn.functional.embedding_bag(input, weight)) tensor([[0., 0., 0.]], device='cuda:0') ``` **After** ```python >>> import torch >>> input = torch.randint(-5, 1, [1, 2], dtype=torch.int64).cuda() >>> weight = torch.rand([2, 3], dtype=torch.float32).cuda() >>> print(torch.nn.functional.embedding_bag(input, weight)) /home/zong/code/pytorch/aten/src/ATen/native/cuda/EmbeddingBag.cu:141: EmbeddingBag_updateOutputKernel_sum_mean: block: [0,0,0], thread: [0,0,0] Assertion `0 <= input_idx && input_idx < numRows` failed. /home/zong/code/pytorch/aten/src/ATen/native/cuda/EmbeddingBag.cu:141: EmbeddingBag_updateOutputKernel_sum_mean: block: [0,0,0], thread: [1,0,0] Assertion `0 <= input_idx && input_idx < numRows` failed. /home/zong/code/pytorch/aten/src/ATen/native/cuda/EmbeddingBag.cu:141: EmbeddingBag_updateOutputKernel_sum_mean: block: [0,0,0], thread: [2,0,0] Assertion `0 <= input_idx && input_idx < numRows` failed. Traceback (most recent call last): File "<stdin>", line 1, in <module> File "/home/zong/code/pytorch/torch/_tensor.py", line 568, in __repr__ return torch._tensor_str._str(self, tensor_contents=tensor_contents) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/home/zong/code/pytorch/torch/_tensor_str.py", line 708, in _str return _str_intern(self, tensor_contents=tensor_contents) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/home/zong/code/pytorch/torch/_tensor_str.py", line 625, in _str_intern tensor_str = _tensor_str(self, indent) ^^^^^^^^^^^^^^^^^^^^^^^^^ File "/home/zong/code/pytorch/torch/_tensor_str.py", line 357, in _tensor_str formatter = _Formatter(get_summarized_data(self) if summarize else self) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/home/zong/code/pytorch/torch/_tensor_str.py", line 146, in __init__ tensor_view, torch.isfinite(tensor_view) & tensor_view.ne(0) ^^^^^^^^^^^^^^^^^^^^^^^^^^^ RuntimeError: CUDA error: device-side assert triggered CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect. For debugging consider passing CUDA_LAUNCH_BLOCKING=1 Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions. ``` ```bash $ pytest test/nn/test_embedding.py ```  ```bash $ lintrunner ```  Pull Request resolved: https://github.com/pytorch/pytorch/pull/140791 Approved by: https://github.com/eqy
This commit is contained in:
parent
9713a6eeca
commit
217a4ddb04
|
|
@ -136,9 +136,10 @@ __global__ void EmbeddingBag_updateOutputKernel_sum_mean(
|
||||||
accscalar_t weightFeatSum = 0;
|
accscalar_t weightFeatSum = 0;
|
||||||
int64_t bag_size_ = 0;
|
int64_t bag_size_ = 0;
|
||||||
for (int64_t emb = begin; emb < end; emb++) {
|
for (int64_t emb = begin; emb < end; emb++) {
|
||||||
bool pad = (input[emb] == padding_idx);
|
index_t input_idx = input[emb];
|
||||||
CUDA_KERNEL_ASSERT(input[emb] < numRows);
|
bool pad = (input_idx == padding_idx);
|
||||||
const int64_t weightRow = input[emb] * weight_stride0;
|
CUDA_KERNEL_ASSERT(0 <= input_idx && input_idx < numRows);
|
||||||
|
const int64_t weightRow = input_idx * weight_stride0;
|
||||||
scalar_t weightValue = weightFeat[weightRow];
|
scalar_t weightValue = weightFeat[weightRow];
|
||||||
weightValue = pad ? static_cast<scalar_t>(0) : weightValue;
|
weightValue = pad ? static_cast<scalar_t>(0) : weightValue;
|
||||||
if (per_sample_weights) {
|
if (per_sample_weights) {
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user