mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
fix for launching kernel invalid config error when calling embedding … (#130994)
…with large index
Fixes #130806
When an output size of 2147483648 (=131072*16384) is expected in the above issue, it throwed out the following error:
RuntimeError: HIP error: invalid configuration argument
What happened was that the second parameter passed to hipLaunchKernel was crazy {2147483648,1,1}.
Found two issues in the Indexing.cu:
1: ptrdiff_t was used but it is signed int, outTotalSize >= 2147483648 can cause overflow when doing [this](39493aa934/aten/src/ATen/native/cuda/Indexing.cu (L1367)):
2: On ROCm, std::min -> ::min did not work as expected when outTotalSize>=2147483648
As the result, 2147483648 was sent to hipLaunchKernel which the GPU does not support such a huge number since this number specifies the number of threads per block. The original code intended to set 128 threads per block, though this is debatable as the perf would not good for latest powerful GPUs (a TODO item to update for perf maybe?) , but at least it would not cause `invalid configuration argument` error.
[Test]
Run the same code snippet in the [issue](https://github.com/pytorch/pytorch/issues/130806), and print the output, its dim and numel(), which looks like below now:
```
output=tensor([[ 0.4044, -0.0244, -0.6865, ..., -0.7800, 0.1175, 1.6726],
[-1.0866, -0.1609, 0.3538, ..., 1.9105, 0.7882, 1.1583],
[-2.2079, 0.3736, 0.3610, ..., -0.2658, -0.0459, 1.3077],
...,
[ 0.8753, -0.7482, -0.1978, ..., 0.9016, 1.1501, -0.5178],
[-1.5845, -0.6277, 1.4520, ..., 0.5733, -2.1198, -0.0915],
[-0.6310, -1.0239, -0.1910, ..., 0.4309, 0.1630, 0.3239]],
device='cuda:0'), dim=2, numel=2147483648
```
Added a large tensor unit test too.
```
/pytorch# pytest test/nn/test_embedding.py -k test_large_tensors
================================================================================== test session starts ===================================================================================
platform linux -- Python 3.9.19, pytest-7.3.2, pluggy-1.4.0
rootdir: /dockerx/development/pytorch
configfile: pytest.ini
plugins: flakefinder-1.1.0, rerunfailures-14.0, xdist-3.3.1, xdoctest-1.1.0, cpp-2.3.0, hypothesis-5.35.1
collected 288 items / 287 deselected / 1 selected
Running 1 items in this shard
test/nn/test_embedding.py . [100%]
=========================================================================== 1 passed, 287 deselected in 3.16s ============================================================================
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/130994
Approved by: https://github.com/jeffdaily, https://github.com/xw285cornell
This commit is contained in:
parent
a8319698b3
commit
637ab85e7f
|
|
@ -688,7 +688,7 @@ REGISTER_CUDA_DISPATCH(index_put_with_sort_quantized_stub, &index_put_with_sort_
|
|||
|
||||
|
||||
// Check tensor dimensions for index operations, and return the slice size.
|
||||
static ptrdiff_t getSliceSize(const Tensor & dst,
|
||||
static size_t getSliceSize(const Tensor & dst,
|
||||
int dim,
|
||||
const Tensor & index,
|
||||
const Tensor & src)
|
||||
|
|
@ -698,7 +698,7 @@ static ptrdiff_t getSliceSize(const Tensor & dst,
|
|||
|
||||
TORCH_CHECK(index.dim() <= 1, "Index must be vector or scalar");
|
||||
|
||||
ptrdiff_t dstSliceSize = 1;
|
||||
size_t dstSliceSize = 1;
|
||||
TORCH_CHECK(dim >= 0 && dim < dstDims, "Indexing dim ", dim, " is out of bounds");
|
||||
for (const auto d: c10::irange(dstDims)) {
|
||||
if (d != dim) {
|
||||
|
|
@ -710,7 +710,7 @@ static ptrdiff_t getSliceSize(const Tensor & dst,
|
|||
TORCH_CHECK(index.numel() == src.size(dim),
|
||||
"length of src.size[dim] is not equal to length of indices");
|
||||
|
||||
ptrdiff_t srcSliceSize = 1;
|
||||
size_t srcSliceSize = 1;
|
||||
bool mismatch = false;
|
||||
|
||||
if (dstDims != srcDims) mismatch = true;
|
||||
|
|
@ -900,11 +900,11 @@ void index_add_cuda_impl(const Tensor& self, int64_t dim, const Tensor& index, c
|
|||
// total size of the tensor ignoring dimension `dim`;
|
||||
// -the number of index we are choosing, which is the total size
|
||||
// of the tensor `index`.
|
||||
const ptrdiff_t sliceSize = getSliceSize(self_, dim, index, source_);
|
||||
const ptrdiff_t sourceTotalSize = source.numel();
|
||||
const int64_t selfAddDimSize = self_.size(dim);
|
||||
const ptrdiff_t numIndex = index.numel();
|
||||
const int64_t selfNumel = self_.numel();
|
||||
const uint64_t sliceSize = getSliceSize(self_, dim, index, source_);
|
||||
const uint64_t sourceTotalSize = source.numel();
|
||||
const uint64_t selfAddDimSize = self_.size(dim);
|
||||
const uint64_t numIndex = index.numel();
|
||||
const uint64_t selfNumel = self_.numel();
|
||||
|
||||
if (sliceSize == 0) {
|
||||
return;
|
||||
|
|
@ -933,11 +933,11 @@ void index_add_cuda_impl(const Tensor& self, int64_t dim, const Tensor& index, c
|
|||
selfAddDimSize, selfNumel, reduce_add, alpha_value); \
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
|
||||
const dim3 smallIndexGrid(std::min(ceil_div(sliceSize, (ptrdiff_t)128), (ptrdiff_t)(mpc * 8)));
|
||||
const dim3 smallIndexBlock(std::min(sliceSize, (ptrdiff_t)128));
|
||||
const dim3 smallIndexGrid(std::min(ceil_div(sliceSize, (uint64_t)128), (uint64_t)(mpc * 8)));
|
||||
const dim3 smallIndexBlock(std::min(sliceSize, (uint64_t)128));
|
||||
|
||||
const dim3 largeIndexGrid(std::min(ceil_div(sourceTotalSize, (ptrdiff_t)128), (ptrdiff_t)(mpc * 8)));
|
||||
const dim3 largeIndexBlock(std::min(sourceTotalSize, (ptrdiff_t)128));
|
||||
const dim3 largeIndexGrid(std::min(ceil_div(sourceTotalSize, (uint64_t)128), (uint64_t)(mpc * 8)));
|
||||
const dim3 largeIndexBlock(std::min(sourceTotalSize, (uint64_t)128));
|
||||
|
||||
if (cuda::detail::canUse32BitIndexMath(result) &&
|
||||
cuda::detail::canUse32BitIndexMath(source) &&
|
||||
|
|
@ -1073,11 +1073,11 @@ void index_reduce_func_cuda_impl(
|
|||
// total size of the tensor ignoring dimension `dim`;
|
||||
// -the number of index we are choosing, which is the total size
|
||||
// of the tensor `index`.
|
||||
ptrdiff_t sliceSize = getSliceSize(self_, dim, index, source_);
|
||||
ptrdiff_t sourceTotalSize = source.numel();
|
||||
int64_t selfReduceDimSize = self_.size(dim);
|
||||
ptrdiff_t numIndex = index.numel();
|
||||
int64_t selfNumel = self_.numel();
|
||||
uint64_t sliceSize = getSliceSize(self_, dim, index, source_);
|
||||
uint64_t sourceTotalSize = source.numel();
|
||||
uint64_t selfReduceDimSize = self_.size(dim);
|
||||
uint64_t numIndex = index.numel();
|
||||
uint64_t selfNumel = self_.numel();
|
||||
|
||||
if (sliceSize == 0) {
|
||||
return;
|
||||
|
|
@ -1106,11 +1106,11 @@ void index_reduce_func_cuda_impl(
|
|||
selfReduceDimSize, selfNumel, reduce_func, alpha_value); \
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
|
||||
dim3 smallIndexGrid(std::min(ceil_div(sliceSize, (ptrdiff_t)128), (ptrdiff_t)(mpc * 8)));
|
||||
dim3 smallIndexBlock(std::min(sliceSize, (ptrdiff_t)128));
|
||||
dim3 smallIndexGrid(std::min(ceil_div(sliceSize, (uint64_t)128), (uint64_t)(mpc * 8)));
|
||||
dim3 smallIndexBlock(std::min(sliceSize, (uint64_t)128));
|
||||
|
||||
dim3 largeIndexGrid(std::min(ceil_div(sourceTotalSize, (ptrdiff_t)128), (ptrdiff_t)(mpc * 8)));
|
||||
dim3 largeIndexBlock(std::min(sourceTotalSize, (ptrdiff_t)128));
|
||||
dim3 largeIndexGrid(std::min(ceil_div(sourceTotalSize, (uint64_t)128), (uint64_t)(mpc * 8)));
|
||||
dim3 largeIndexBlock(std::min(sourceTotalSize, (uint64_t)128));
|
||||
|
||||
if (cuda::detail::canUse32BitIndexMath(result) &&
|
||||
cuda::detail::canUse32BitIndexMath(source) &&
|
||||
|
|
@ -1342,8 +1342,8 @@ void index_select_out_cuda_impl(
|
|||
const Tensor& self,
|
||||
long dim,
|
||||
const Tensor& index) {
|
||||
ptrdiff_t numIndices = index.numel();
|
||||
int selfDims = self.dim() == 0 ? 1 : self.dim();
|
||||
uint64_t numIndices = index.numel();
|
||||
uint64_t selfDims = self.dim() == 0 ? 1 : self.dim();
|
||||
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
|
|
@ -1364,7 +1364,7 @@ void index_select_out_cuda_impl(
|
|||
at::native::resize_output(out, newSize);
|
||||
}
|
||||
|
||||
ptrdiff_t outTotalSize = out.numel();
|
||||
uint64_t outTotalSize = out.numel();
|
||||
if (outTotalSize == 0) {
|
||||
return;
|
||||
}
|
||||
|
|
@ -1376,8 +1376,8 @@ void index_select_out_cuda_impl(
|
|||
// total size of the tensor ignoring dimension `dim`;
|
||||
// -the number of indices we are choosing, which is the total size
|
||||
// of the tensor `indices`.
|
||||
int64_t selfSelectDimSize = self.dim() == 0 ? 1 : self.size(dim);
|
||||
ptrdiff_t sliceSize = outTotalSize / numIndices;
|
||||
uint64_t selfSelectDimSize = self.dim() == 0 ? 1 : self.size(dim);
|
||||
uint64_t sliceSize = outTotalSize / numIndices;
|
||||
|
||||
int mpc = at::cuda::getCurrentDeviceProperties()->multiProcessorCount;
|
||||
|
||||
|
|
@ -1400,11 +1400,14 @@ void index_select_out_cuda_impl(
|
|||
selfSelectDimSize); \
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
|
||||
dim3 smallIndexGrid(std::min(ceil_div(sliceSize, (ptrdiff_t)128), (ptrdiff_t)(mpc * 8)));
|
||||
dim3 smallIndexBlock(std::min(sliceSize, (ptrdiff_t)128));
|
||||
dim3 smallIndexGrid(std::min(ceil_div(sliceSize, (uint64_t)128), (uint64_t) (mpc * 8)));
|
||||
dim3 smallIndexBlock(std::min(sliceSize, (uint64_t)128));
|
||||
|
||||
dim3 largeIndexGrid(std::min(ceil_div(outTotalSize, (ptrdiff_t)128), (ptrdiff_t)(mpc * 8)));
|
||||
dim3 largeIndexBlock(std::min(outTotalSize, (ptrdiff_t)128));
|
||||
dim3 largeIndexGrid(std::min(ceil_div(outTotalSize, (uint64_t)128), (uint64_t) (mpc * 8)));
|
||||
// for issue https://github.com/pytorch/pytorch/issues/130806 there are two problems
|
||||
// 1: ptrdiff_t was used but it is signed int, outTotalSize of 2147483648 can cause overflow
|
||||
// 2: On ROCm, std::min -> ::min did not work as expected on when outTotalSize>=2147483648
|
||||
dim3 largeIndexBlock( (outTotalSize < 128) ? outTotalSize : 128 );
|
||||
if (cuda::detail::canUse32BitIndexMath(out) &&
|
||||
cuda::detail::canUse32BitIndexMath(self) &&
|
||||
cuda::detail::canUse32BitIndexMath(index)) {
|
||||
|
|
|
|||
|
|
@ -12,12 +12,14 @@ from torch.testing._internal.common_device_type import (
|
|||
dtypes,
|
||||
dtypesIfCUDA,
|
||||
instantiate_device_type_tests,
|
||||
largeTensorTest,
|
||||
onlyCUDA,
|
||||
onlyNativeDeviceTypes,
|
||||
skipCUDAIf,
|
||||
skipMeta,
|
||||
TEST_WITH_ROCM,
|
||||
)
|
||||
|
||||
from torch.testing._internal.common_nn import NNTestCase
|
||||
from torch.testing._internal.common_utils import (
|
||||
_assertGradAndGradgradChecks,
|
||||
|
|
@ -180,6 +182,15 @@ class TestEmbeddingNN(NNTestCase):
|
|||
|
||||
self.assertEqual(res_old, res_F)
|
||||
|
||||
# https://github.com/pytorch/pytorch/issues/130806
|
||||
@largeTensorTest("40GB", device="cuda")
|
||||
def test_large_tensors(self):
|
||||
input = torch.randint(low=0, high=16032, size=[131072], device="cuda")
|
||||
w = torch.randn([16032, 16384], device="cuda")
|
||||
out = torch.nn.functional.embedding(input, w)
|
||||
self.assertEqual(out.dim(), 2)
|
||||
self.assertEqual(out.numel(), 2147483648)
|
||||
|
||||
def test_embedding_bag_functional(self):
|
||||
a = torch.tensor([[1, 3, 2], [0, 2, 1]], dtype=torch.long)
|
||||
embeddings = torch.rand(4, 3, requires_grad=True)
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user