Mark bincount CUDA deterministic if weights are not given (#105244)
Fixes #98316

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105244
Approved by: https://github.com/mikaylagawarecki
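For context, a minimal sketch of the user-visible effect (plain PyTorch usage, not part of the patch; assumes a CUDA device is available):

    import torch

    torch.use_deterministic_algorithms(True)

    x = torch.randint(0, 10, (1000,), device="cuda")
    w = torch.rand(1000, device="cuda")

    # Unweighted bincount accumulates integer counts, so after this change
    # it no longer trips the nondeterminism alert.
    counts = torch.bincount(x)

    # With float weights, CUDA atomicAdd accumulation order still varies,
    # so deterministic mode raises a RuntimeError here.
    try:
        torch.bincount(x, weights=w)
    except RuntimeError as err:
        print(err)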
This commit is contained in:
parent e9fd815226
commit fcb7d4b358
@@ -366,9 +366,12 @@ Tensor _bincount_cuda(
   c10::MaybeOwned<Tensor> weights_maybe_owned = at::borrow_from_optional_tensor(weights_opt);
   const Tensor& weights = *weights_maybe_owned;
 
-  // See Note [Writing Nondeterministic Operations]
-  // Nondeterministic because of atomicAdd usage
-  globalContext().alertNotDeterministic("_bincount_cuda");
+  if (weights_opt.has_value()) {
+    // See Note [Writing Nondeterministic Operations]
+    // Nondeterministic if weights are given, because of floating point
+    // atomicAdd usage
+    globalContext().alertNotDeterministic("_bincount_cuda");
+  }
   return AT_DISPATCH_INTEGRAL_TYPES(self.scalar_type(), "bincount_cuda", [&] {
     const auto scalar = weights.scalar_type();
     if (scalar == ScalarType::Undefined || scalar == ScalarType::Float)
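The reason the unweighted path can skip the alert: bin counts are accumulated as integers, and integer addition is associative, so the order in which concurrent atomicAdd operations land cannot change the result. Weighted bincount accumulates floating point values, where addition is not associative. A small illustration (plain Python, not part of the patch):

    # Floating point addition is not associative, so the order in which
    # concurrent atomicAdd operations land can change the final sum.
    a, b, c = 0.1, 1e20, -1e20
    print((a + b) + c)   # 0.0
    print(a + (b + c))   # 0.1

    # Integer addition has no such ordering sensitivity, which is why the
    # unweighted CUDA bincount path is deterministic.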
@@ -1746,11 +1746,20 @@ else:
     @skipIfMps
     def test_nondeterministic_alert_bincount(self, device):
         a = torch.tensor([], device=device, dtype=torch.long)
+        weights = torch.tensor([], device=device)
+
         for op_call in [torch.bincount, torch.Tensor.bincount]:
+            # Error should only be raised when device is CUDA and weights are
+            # given
+            self.check_nondeterministic_alert(
+                lambda: op_call(a, weights),
+                '_bincount_cuda',
+                torch.device(device).type == 'cuda')
+
             self.check_nondeterministic_alert(
                 lambda: op_call(a),
                 '_bincount_cuda',
-                torch.device(device).type == 'cuda')
+                False)
 
     # Ensures that kthvalue throws nondeterministic alerts in the correct cases
     @dtypes(torch.double)
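The test drives the suite's check_nondeterministic_alert helper. For readers outside the test harness, a hedged standalone sketch of the same assertion (expect_alert is a hypothetical stand-in for that helper; assumes a CUDA device):

    import torch

    def expect_alert(fn, should_alert):
        # Enable deterministic mode, run the op, and record whether the
        # nondeterminism alert surfaced as a RuntimeError.
        prev = torch.are_deterministic_algorithms_enabled()
        torch.use_deterministic_algorithms(True)
        try:
            fn()
            raised = False
        except RuntimeError:
            raised = True
        finally:
            torch.use_deterministic_algorithms(prev)
        assert raised == should_alert

    a = torch.tensor([], device="cuda", dtype=torch.long)
    w = torch.tensor([], device="cuda")
    expect_alert(lambda: torch.bincount(a, weights=w), should_alert=True)
    expect_alert(lambda: torch.bincount(a), should_alert=False)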
@@ -743,7 +743,8 @@ def use_deterministic_algorithms(mode, *, warn_only=False):
         * :func:`torch.Tensor.put_` when ``accumulate=False``
         * :func:`torch.Tensor.put_` when ``accumulate=True`` and called on a CUDA tensor
         * :func:`torch.histc` when called on a CUDA tensor
-        * :func:`torch.bincount` when called on a CUDA tensor
+        * :func:`torch.bincount` when called on a CUDA tensor and ``weights``
+          tensor is given
         * :func:`torch.kthvalue` with called on a CUDA tensor
         * :func:`torch.median` with indices output when called on a CUDA tensor
         * :func:`torch.nn.functional.grid_sample` when attempting to differentiate a CUDA tensor
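The same docstring also documents the warn_only flag shown in the hunk header; a brief hedged usage sketch of auditing with warnings instead of errors (assumes a CUDA device):

    import torch

    # warn_only=True makes the listed ops emit a warning rather than raise,
    # which is convenient when scanning a model for nondeterministic kernels.
    torch.use_deterministic_algorithms(True, warn_only=True)

    x = torch.randint(0, 10, (100,), device="cuda")
    w = torch.rand(100, device="cuda")
    torch.bincount(x, weights=w)   # warns: no deterministic implementation
    torch.bincount(x)              # silent after this change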