Mark bincount CUDA deterministic if weights are not given (#105244)

Fixes #98316

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105244
Approved by: https://github.com/mikaylagawarecki
This commit is contained in:
Kurt Mohler 2023-07-18 01:16:47 +00:00 committed by PyTorch MergeBot
parent e9fd815226
commit fcb7d4b358
3 changed files with 18 additions and 5 deletions

View File

@ -366,9 +366,12 @@ Tensor _bincount_cuda(
c10::MaybeOwned<Tensor> weights_maybe_owned = at::borrow_from_optional_tensor(weights_opt);
const Tensor& weights = *weights_maybe_owned;
// See Note [Writing Nondeterministic Operations]
// Nondeterministic because of atomicAdd usage
globalContext().alertNotDeterministic("_bincount_cuda");
if (weights_opt.has_value()) {
// See Note [Writing Nondeterministic Operations]
// Nondeterministic if weights are given, because of floating point
// atomicAdd usage
globalContext().alertNotDeterministic("_bincount_cuda");
}
return AT_DISPATCH_INTEGRAL_TYPES(self.scalar_type(), "bincount_cuda", [&] {
const auto scalar = weights.scalar_type();
if (scalar == ScalarType::Undefined || scalar == ScalarType::Float)

View File

@ -1746,11 +1746,20 @@ else:
@skipIfMps
def test_nondeterministic_alert_bincount(self, device):
a = torch.tensor([], device=device, dtype=torch.long)
weights = torch.tensor([], device=device)
for op_call in [torch.bincount, torch.Tensor.bincount]:
# Error should only be raised when device is CUDA and weights are
# given
self.check_nondeterministic_alert(
lambda: op_call(a, weights),
'_bincount_cuda',
torch.device(device).type == 'cuda')
self.check_nondeterministic_alert(
lambda: op_call(a),
'_bincount_cuda',
torch.device(device).type == 'cuda')
False)
# Ensures that kthvalue throws nondeterministic alerts in the correct cases
@dtypes(torch.double)

View File

@ -743,7 +743,8 @@ def use_deterministic_algorithms(mode, *, warn_only=False):
* :func:`torch.Tensor.put_` when ``accumulate=False``
* :func:`torch.Tensor.put_` when ``accumulate=True`` and called on a CUDA tensor
* :func:`torch.histc` when called on a CUDA tensor
* :func:`torch.bincount` when called on a CUDA tensor
* :func:`torch.bincount` when called on a CUDA tensor and ``weights``
tensor is given
* :func:`torch.kthvalue` with called on a CUDA tensor
* :func:`torch.median` with indices output when called on a CUDA tensor
* :func:`torch.nn.functional.grid_sample` when attempting to differentiate a CUDA tensor