mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 00:20:18 +01:00
Fix race condition and make CUDA kthvalue deterministic (#165762)
The gatherKthValue kernel had a race condition where multiple threads could write to the same output location without synchronization when duplicate k-th values exist, resulting in non-deterministic output. Changes: - aten/src/ATen/native/cuda/Sorting.cu: Use atomicMin with shared memory to deterministically find minimum index. Add early termination and remove redundant inRange checks. (We have to cast the index to `int32_t`, but this is already assumed to fit earlier in the kernel.) - aten/src/ATen/native/cuda/Sorting.cpp: Remove non-deterministic alert since kthvalue is now deterministic on CUDA. - torch/__init__.py: Remove kthvalue from non-deterministic operations list and remove kthvalue example from use_deterministic_algorithms() docstring. - test/test_torch.py: Remove test_nondeterministic_alert_kthvalue since kthvalue no longer raises alerts on CUDA. Benefits: - Deterministic: always returns minimum index when duplicates exist - Potential performance improvement on large arrays with repetitions Test Results: - All existing PyTorch tests pass (test_kthvalue) - Custom determinism tests confirm consistent results - Custom CUDA vs CPU correctness validated across 50+ scenarios - Custom performance benchmarks show improvements with no visible regressions Addresses #165227 Pull Request resolved: https://github.com/pytorch/pytorch/pull/165762 Approved by: https://github.com/ngimel, https://github.com/eqy
This commit is contained in:
parent
9d0b77f4cd
commit
79a4a9c02e
|
|
@ -43,6 +43,12 @@ std::tuple<Tensor&, Tensor&> kthvalue_out_impl_cuda(
|
|||
TORCH_CHECK(k >= 1 && k <= slicesize,
|
||||
"kthvalue(): selected number k out of range for dimension ", dim);
|
||||
|
||||
TORCH_CHECK(
|
||||
slicesize <= std::numeric_limits<int32_t>::max(),
|
||||
"kthvalue(): dimension ", dim, " is too large (", slicesize,
|
||||
"). The current CUDA implementation supports dimension sizes up to ",
|
||||
std::numeric_limits<int32_t>::max());
|
||||
|
||||
at::assert_no_overlap(self, values);
|
||||
|
||||
_reduction_with_indices_allocate_or_resize_output(
|
||||
|
|
@ -163,10 +169,6 @@ std::tuple<Tensor&, Tensor&> kthvalue_out_cuda(
|
|||
bool keepdim,
|
||||
Tensor& values,
|
||||
Tensor& indices) {
|
||||
// See note [Writing Nondeterministic Operations]
|
||||
// If there are duplicate elements of the kth value, the procedure for choosing which
|
||||
// of the duplicates to use for the indices output is nondeterministic.
|
||||
at::globalContext().alertNotDeterministic("kthvalue CUDA");
|
||||
auto result = [&]() {
|
||||
NoNamesGuard guard;
|
||||
// `kthvalue_out_impl_cuda` expects contiguous in input `self`.
|
||||
|
|
|
|||
|
|
@ -65,25 +65,34 @@ __global__ void gatherKthValue(
|
|||
&kValue);
|
||||
|
||||
// Find the index of the k-th highest element
|
||||
index_t kValueIndex = 0;
|
||||
bool foundKValue = false;
|
||||
__shared__ int32_t minIndexFound;
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
minIndexFound = static_cast<int32_t>(inputSliceSize);
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
for (index_t i = threadIdx.x; i < inputSliceSize; i += blockDim.x) {
|
||||
bool inRange = (i < inputSliceSize);
|
||||
scalar_t v = inRange ? doLdg(&inputSliceStart[i * inputWithinSliceStride])
|
||||
: static_cast<scalar_t>(0);
|
||||
bool isKValue = inRange &&
|
||||
((v == kValue) || (at::_isnan(v) && at::_isnan(kValue)));
|
||||
if (isKValue) {
|
||||
kValueIndex = i;
|
||||
foundKValue = true;
|
||||
break;
|
||||
}
|
||||
// Early exit based on best-so-far
|
||||
if (i >= minIndexFound) {
|
||||
break;
|
||||
}
|
||||
|
||||
scalar_t v = doLdg(&inputSliceStart[i * inputWithinSliceStride]);
|
||||
bool isKValue =
|
||||
((v == kValue) || (at::_isnan(v) && at::_isnan(kValue)));
|
||||
|
||||
if (isKValue) {
|
||||
atomicMin(&minIndexFound, static_cast<int32_t>(i));
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (foundKValue) {
|
||||
kthValueSliceStart[0] = kValue;
|
||||
indicesSliceStart[0] = kValueIndex;
|
||||
__syncthreads();
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
indicesSliceStart[0] = static_cast<index_t>(minIndexFound);
|
||||
kthValueSliceStart[0] = kValue;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -1837,30 +1837,6 @@ class TestTorchDeviceType(TestCase):
|
|||
'_bincount_cuda',
|
||||
False)
|
||||
|
||||
# Ensures that kthvalue throws nondeterministic alerts in the correct cases
|
||||
@dtypes(torch.double)
|
||||
def test_nondeterministic_alert_kthvalue(self, device, dtype):
|
||||
def test_func(call_type):
|
||||
S = 10
|
||||
k = 5
|
||||
a = torch.randn(S, device=device)
|
||||
if call_type == 'function':
|
||||
torch.kthvalue(a, k)
|
||||
elif call_type == 'method':
|
||||
a.kthvalue(k)
|
||||
elif call_type == 'out':
|
||||
values = torch.empty_like(a)
|
||||
indices = torch.empty((), device=device, dtype=torch.long)
|
||||
torch.kthvalue(a, k, out=(values, indices))
|
||||
else:
|
||||
self.fail(f"'{call_type}' is not a valid call type")
|
||||
|
||||
for call_type in ['function', 'method', 'out']:
|
||||
self.check_nondeterministic_alert(
|
||||
lambda: test_func('function'),
|
||||
'kthvalue CUDA',
|
||||
torch.device(device).type == 'cuda')
|
||||
|
||||
@skipIfMPS
|
||||
@skipIfTorchInductor("https://github.com/pytorch/pytorch/issues/113707")
|
||||
def test_nondeterministic_alert_grid_sample_2d(self, device):
|
||||
|
|
|
|||
|
|
@ -1412,7 +1412,6 @@ def use_deterministic_algorithms(
|
|||
* :func:`torch.histc` when called on a CUDA tensor
|
||||
* :func:`torch.bincount` when called on a CUDA tensor and ``weights``
|
||||
tensor is given
|
||||
* :func:`torch.kthvalue` with called on a CUDA tensor
|
||||
* :func:`torch.median` with indices output when called on a CUDA tensor
|
||||
* :func:`torch.nn.functional.grid_sample` when attempting to differentiate a CUDA tensor
|
||||
* :func:`torch.cumsum` when called on a CUDA tensor when dtype is floating point or complex
|
||||
|
|
@ -1471,11 +1470,6 @@ def use_deterministic_algorithms(
|
|||
>>> # xdoctest: +SKIP
|
||||
>>> torch.use_deterministic_algorithms(True)
|
||||
|
||||
# Forward mode nondeterministic error
|
||||
>>> torch.randn(10, device='cuda').kthvalue(1)
|
||||
...
|
||||
RuntimeError: kthvalue CUDA does not have a deterministic implementation...
|
||||
|
||||
# Backward mode nondeterministic error
|
||||
>>> torch.nn.AvgPool3d(1)(torch.randn(3, 4, 5, 6, requires_grad=True).cuda()).sum().backward()
|
||||
...
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user