Add range check for embedding_bag input index >= 0 on CUDA device (#140791)

Fixes #89362

**Test Result**

**Before**

```python
>>> import torch
>>> input = torch.randint(-5, 1, [1, 2], dtype=torch.int64).cuda()
>>> weight = torch.rand([2, 3], dtype=torch.float32).cuda()
>>> print(torch.nn.functional.embedding_bag(input, weight))
tensor([[0., 0., 0.]], device='cuda:0')
```

**After**

```python
>>> import torch
>>> input = torch.randint(-5, 1, [1, 2], dtype=torch.int64).cuda()
>>> weight = torch.rand([2, 3], dtype=torch.float32).cuda()
>>> print(torch.nn.functional.embedding_bag(input, weight))
/home/zong/code/pytorch/aten/src/ATen/native/cuda/EmbeddingBag.cu:141: EmbeddingBag_updateOutputKernel_sum_mean: block: [0,0,0], thread: [0,0,0] Assertion `0 <= input_idx && input_idx < numRows` failed.
/home/zong/code/pytorch/aten/src/ATen/native/cuda/EmbeddingBag.cu:141: EmbeddingBag_updateOutputKernel_sum_mean: block: [0,0,0], thread: [1,0,0] Assertion `0 <= input_idx && input_idx < numRows` failed.
/home/zong/code/pytorch/aten/src/ATen/native/cuda/EmbeddingBag.cu:141: EmbeddingBag_updateOutputKernel_sum_mean: block: [0,0,0], thread: [2,0,0] Assertion `0 <= input_idx && input_idx < numRows` failed.
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/zong/code/pytorch/torch/_tensor.py", line 568, in __repr__
    return torch._tensor_str._str(self, tensor_contents=tensor_contents)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/zong/code/pytorch/torch/_tensor_str.py", line 708, in _str
    return _str_intern(self, tensor_contents=tensor_contents)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/zong/code/pytorch/torch/_tensor_str.py", line 625, in _str_intern
    tensor_str = _tensor_str(self, indent)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/zong/code/pytorch/torch/_tensor_str.py", line 357, in _tensor_str
    formatter = _Formatter(get_summarized_data(self) if summarize else self)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/zong/code/pytorch/torch/_tensor_str.py", line 146, in __init__
    tensor_view, torch.isfinite(tensor_view) & tensor_view.ne(0)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

```

```bash
$ pytest test/nn/test_embedding.py
```
![image](https://github.com/user-attachments/assets/6a5ec759-a3dc-4d51-9e5e-ec79c0aac526)

```bash
$ lintrunner
```
![image](https://github.com/user-attachments/assets/2ce4ac24-74fb-4181-9510-18b96a2c2acb)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140791
Approved by: https://github.com/eqy
This commit is contained in:
zeshengzong 2024-12-20 05:47:24 +00:00 committed by PyTorch MergeBot
parent 9713a6eeca
commit 217a4ddb04

View File

@@ -136,9 +136,10 @@ __global__ void EmbeddingBag_updateOutputKernel_sum_mean(
 accscalar_t weightFeatSum = 0;
 int64_t bag_size_ = 0;
 for (int64_t emb = begin; emb < end; emb++) {
-bool pad = (input[emb] == padding_idx);
-CUDA_KERNEL_ASSERT(input[emb] < numRows);
-const int64_t weightRow = input[emb] * weight_stride0;
+index_t input_idx = input[emb];
+bool pad = (input_idx == padding_idx);
+CUDA_KERNEL_ASSERT(0 <= input_idx && input_idx < numRows);
+const int64_t weightRow = input_idx * weight_stride0;
 scalar_t weightValue = weightFeat[weightRow];
 weightValue = pad ? static_cast<scalar_t>(0) : weightValue;
 if (per_sample_weights) {