Fix race condition and make CUDA kthvalue deterministic (#165762)

The gatherKthValue kernel had a race condition: when the k-th value occurred more than once in a slice, every thread that found a match wrote its own index to the same output location without synchronization, so the returned index was non-deterministic.

Changes:
- aten/src/ATen/native/cuda/Sorting.cu: Use atomicMin on a shared-memory variable to deterministically find the minimum matching index. Add early termination and remove the redundant inRange checks. (The index has to be cast to `int32_t`, but the kernel already assumes earlier on that it fits.)
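- aten/src/ATen/native/cuda/Sorting.cpp: Remove the non-deterministic alert, since kthvalue is now deterministic on CUDA. Add a TORCH_CHECK that the size of the selected dimension does not exceed the `int32_t` maximum.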
- torch/__init__.py: Remove kthvalue from non-deterministic operations list and remove kthvalue example from use_deterministic_algorithms() docstring.
- test/test_torch.py: Remove test_nondeterministic_alert_kthvalue since kthvalue no longer raises alerts on CUDA.

Benefits:
- Deterministic: always returns the minimum index when duplicates of the k-th value exist
- Potential performance improvement on large arrays with repetitions

Test Results:
- All existing PyTorch tests pass (test_kthvalue)
- Custom determinism tests confirm consistent results
- Custom CUDA vs CPU correctness validated across 50+ scenarios
- Custom performance benchmarks show improvements with no visible regressions

Addresses #165227

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165762
Approved by: https://github.com/ngimel, https://github.com/eqy
This commit is contained in:
nick-kuhn 2025-10-25 00:45:53 +00:00 committed by PyTorch MergeBot
parent 9d0b77f4cd
commit 79a4a9c02e
4 changed files with 30 additions and 49 deletions

View File

@ -43,6 +43,12 @@ std::tuple<Tensor&, Tensor&> kthvalue_out_impl_cuda(
TORCH_CHECK(k >= 1 && k <= slicesize,
"kthvalue(): selected number k out of range for dimension ", dim);
TORCH_CHECK(
slicesize <= std::numeric_limits<int32_t>::max(),
"kthvalue(): dimension ", dim, " is too large (", slicesize,
"). The current CUDA implementation supports dimension sizes up to ",
std::numeric_limits<int32_t>::max());
at::assert_no_overlap(self, values);
_reduction_with_indices_allocate_or_resize_output(
@ -163,10 +169,6 @@ std::tuple<Tensor&, Tensor&> kthvalue_out_cuda(
bool keepdim,
Tensor& values,
Tensor& indices) {
// See note [Writing Nondeterministic Operations]
// If there are duplicate elements of the kth value, the procedure for choosing which
// of the duplicates to use for the indices output is nondeterministic.
at::globalContext().alertNotDeterministic("kthvalue CUDA");
auto result = [&]() {
NoNamesGuard guard;
// `kthvalue_out_impl_cuda` expects contiguous in input `self`.

View File

@ -65,25 +65,34 @@ __global__ void gatherKthValue(
&kValue);
// Find the index of the k-th highest element
index_t kValueIndex = 0;
bool foundKValue = false;
__shared__ int32_t minIndexFound;
if (threadIdx.x == 0) {
minIndexFound = static_cast<int32_t>(inputSliceSize);
}
__syncthreads();
for (index_t i = threadIdx.x; i < inputSliceSize; i += blockDim.x) {
bool inRange = (i < inputSliceSize);
scalar_t v = inRange ? doLdg(&inputSliceStart[i * inputWithinSliceStride])
: static_cast<scalar_t>(0);
bool isKValue = inRange &&
((v == kValue) || (at::_isnan(v) && at::_isnan(kValue)));
if (isKValue) {
kValueIndex = i;
foundKValue = true;
break;
}
// Early exit based on best-so-far
if (i >= minIndexFound) {
break;
}
scalar_t v = doLdg(&inputSliceStart[i * inputWithinSliceStride]);
bool isKValue =
((v == kValue) || (at::_isnan(v) && at::_isnan(kValue)));
if (isKValue) {
atomicMin(&minIndexFound, static_cast<int32_t>(i));
break;
}
}
if (foundKValue) {
kthValueSliceStart[0] = kValue;
indicesSliceStart[0] = kValueIndex;
__syncthreads();
if (threadIdx.x == 0) {
indicesSliceStart[0] = static_cast<index_t>(minIndexFound);
kthValueSliceStart[0] = kValue;
}
}

View File

@ -1837,30 +1837,6 @@ class TestTorchDeviceType(TestCase):
'_bincount_cuda',
False)
# Ensures that kthvalue throws nondeterministic alerts in the correct cases
@dtypes(torch.double)
def test_nondeterministic_alert_kthvalue(self, device, dtype):
def test_func(call_type):
S = 10
k = 5
a = torch.randn(S, device=device)
if call_type == 'function':
torch.kthvalue(a, k)
elif call_type == 'method':
a.kthvalue(k)
elif call_type == 'out':
values = torch.empty_like(a)
indices = torch.empty((), device=device, dtype=torch.long)
torch.kthvalue(a, k, out=(values, indices))
else:
self.fail(f"'{call_type}' is not a valid call type")
for call_type in ['function', 'method', 'out']:
self.check_nondeterministic_alert(
lambda: test_func('function'),
'kthvalue CUDA',
torch.device(device).type == 'cuda')
@skipIfMPS
@skipIfTorchInductor("https://github.com/pytorch/pytorch/issues/113707")
def test_nondeterministic_alert_grid_sample_2d(self, device):

View File

@ -1412,7 +1412,6 @@ def use_deterministic_algorithms(
* :func:`torch.histc` when called on a CUDA tensor
* :func:`torch.bincount` when called on a CUDA tensor and ``weights``
tensor is given
* :func:`torch.kthvalue` with called on a CUDA tensor
* :func:`torch.median` with indices output when called on a CUDA tensor
* :func:`torch.nn.functional.grid_sample` when attempting to differentiate a CUDA tensor
* :func:`torch.cumsum` when called on a CUDA tensor when dtype is floating point or complex
@ -1471,11 +1470,6 @@ def use_deterministic_algorithms(
>>> # xdoctest: +SKIP
>>> torch.use_deterministic_algorithms(True)
# Forward mode nondeterministic error
>>> torch.randn(10, device='cuda').kthvalue(1)
...
RuntimeError: kthvalue CUDA does not have a deterministic implementation...
# Backward mode nondeterministic error
>>> torch.nn.AvgPool3d(1)(torch.randn(3, 4, 5, 6, requires_grad=True).cuda()).sum().backward()
...