mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Revert "Add deterministic path for CUDA cumsum (#136224)"
This reverts commit d1bb8e828f.
Reverted https://github.com/pytorch/pytorch/pull/136224 on behalf of https://github.com/atalman because it breaks internal CI ([comment](https://github.com/pytorch/pytorch/pull/136224#issuecomment-2379214226))
This commit is contained in:
parent
c2637a7b26
commit
e9d2765ec8
|
|
@ -1739,33 +1739,11 @@ else:
|
|||
'embedding_bag_backward_cuda_max',
|
||||
torch.device(device).type == 'cuda')
|
||||
|
||||
@skipIfTorchInductor("https://github.com/pytorch/pytorch/issues/113707")
|
||||
@onlyCUDA
|
||||
def test_deterministic_cumsum(self, device):
|
||||
test_cases = [
|
||||
# size, dim
|
||||
[(2, 3, 4), 0],
|
||||
[(2, 3, 4), 1],
|
||||
[(2, 3, 4), 2],
|
||||
[(1000, 10, 2), 0],
|
||||
]
|
||||
for size, dim in test_cases:
|
||||
input = 100 * torch.randn(*size, device=device)
|
||||
with DeterministicGuard(True):
|
||||
res0 = input.cumsum(dim)
|
||||
for _ in range(3):
|
||||
res1 = input.cumsum(dim)
|
||||
self.assertEqual(res0, res1, atol=0, rtol=0)
|
||||
|
||||
res_cpu = input.cpu().cumsum(dim)
|
||||
self.assertEqual(res0, res_cpu, atol=1e-3, rtol=1e-2)
|
||||
|
||||
|
||||
@dtypes(*all_types_and_complex_and(torch.bool))
|
||||
@skipIfTorchInductor("https://github.com/pytorch/pytorch/issues/113707")
|
||||
def test_nondeterministic_alert_cumsum(self, device, dtype):
|
||||
input = make_tensor((10,), dtype=dtype, device=device, low=-9, high=9)
|
||||
should_alert = False
|
||||
should_alert = torch.device(device).type == 'cuda' and (dtype.is_floating_point or dtype.is_complex)
|
||||
|
||||
for op_call in [torch.Tensor.cumsum, torch.cumsum]:
|
||||
self.check_nondeterministic_alert(
|
||||
|
|
|
|||
|
|
@ -136,7 +136,6 @@ blocklist = [
|
|||
"requires_grad",
|
||||
"range",
|
||||
# defined in functional
|
||||
"cumsum",
|
||||
"einsum",
|
||||
# Somehow, these are defined in both _C and in functional. Ick!
|
||||
"broadcast_tensors",
|
||||
|
|
|
|||
|
|
@ -1235,7 +1235,6 @@ def use_deterministic_algorithms(
|
|||
* :func:`torch.Tensor.put_` with ``accumulate=True`` when called on a CPU
|
||||
tensor
|
||||
* :func:`torch.Tensor.scatter_add_` when called on a CUDA tensor
|
||||
* :func:`torch.cumsum` when called on a CUDA tensor
|
||||
* :func:`torch.gather` when called on a CUDA tensor that requires grad
|
||||
* :func:`torch.index_add` when called on a CUDA tensor
|
||||
* :func:`torch.index_select` when attempting to differentiate a CUDA tensor
|
||||
|
|
@ -1282,6 +1281,7 @@ def use_deterministic_algorithms(
|
|||
* :func:`torch.kthvalue` when called on a CUDA tensor
|
||||
* :func:`torch.median` with indices output when called on a CUDA tensor
|
||||
* :func:`torch.nn.functional.grid_sample` when attempting to differentiate a CUDA tensor
|
||||
* :func:`torch.cumsum` when called on a CUDA tensor when dtype is floating point or complex
|
||||
* :func:`torch.Tensor.scatter_reduce` when ``reduce='prod'`` and called on a CUDA tensor
|
||||
* :func:`torch.Tensor.resize_` when called with a quantized tensor
|
||||
|
||||
|
|
|
|||
|
|
@ -846,37 +846,6 @@ class Tensor(torch._C.TensorBase):
|
|||
|
||||
return _symeig(self, eigenvectors=eigenvectors)
|
||||
|
||||
def cumsum(
|
||||
self,
|
||||
dim=None,
|
||||
*,
|
||||
dtype=None,
|
||||
out=None,
|
||||
axis=None,
|
||||
):
|
||||
r"""
|
||||
cumsum(dim, dtype=None) -> Tensor
|
||||
|
||||
See :func:`torch.cumsum`
|
||||
"""
|
||||
if axis is not None and dim is not None:
|
||||
raise RuntimeError("expected either 'dim' or 'axis' to be given, not both")
|
||||
elif axis is not None:
|
||||
dim = axis
|
||||
if has_torch_function_unary(self):
|
||||
return handle_torch_function(
|
||||
Tensor.cumsum,
|
||||
(self,),
|
||||
self,
|
||||
dim,
|
||||
dtype=dtype,
|
||||
out=out,
|
||||
)
|
||||
if out is None:
|
||||
return torch.cumsum(self, dim, dtype=dtype)
|
||||
else:
|
||||
return torch.cumsum(self, dim, dtype=dtype, out=out)
|
||||
|
||||
def lu(self, pivot=True, get_infos=False):
|
||||
r"""See :func:`torch.lu`"""
|
||||
# If get_infos is True, then we don't need to check for errors and vice versa
|
||||
|
|
|
|||
|
|
@ -1497,6 +1497,15 @@ In-place version of :meth:`~Tensor.cumprod`
|
|||
""",
|
||||
)
|
||||
|
||||
add_docstr_all(
|
||||
"cumsum",
|
||||
r"""
|
||||
cumsum(dim, dtype=None) -> Tensor
|
||||
|
||||
See :func:`torch.cumsum`
|
||||
""",
|
||||
)
|
||||
|
||||
add_docstr_all(
|
||||
"cumsum_",
|
||||
r"""
|
||||
|
|
|
|||
|
|
@ -3317,6 +3317,38 @@ Example::
|
|||
""".format(**reduceops_common_args),
|
||||
)
|
||||
|
||||
add_docstr(
|
||||
torch.cumsum,
|
||||
r"""
|
||||
cumsum(input, dim, *, dtype=None, out=None) -> Tensor
|
||||
|
||||
Returns the cumulative sum of elements of :attr:`input` in the dimension
|
||||
:attr:`dim`.
|
||||
|
||||
For example, if :attr:`input` is a vector of size N, the result will also be
|
||||
a vector of size N, with elements.
|
||||
|
||||
.. math::
|
||||
y_i = x_1 + x_2 + x_3 + \dots + x_i
|
||||
|
||||
Args:
|
||||
{input}
|
||||
dim (int): the dimension to do the operation over
|
||||
|
||||
Keyword args:
|
||||
{dtype}
|
||||
{out}
|
||||
|
||||
Example::
|
||||
|
||||
>>> a = torch.randint(1, 20, (10,))
|
||||
>>> a
|
||||
tensor([13, 7, 3, 10, 13, 3, 15, 10, 9, 10])
|
||||
>>> torch.cumsum(a, dim=0)
|
||||
tensor([13, 20, 23, 33, 46, 49, 64, 74, 83, 93])
|
||||
""".format(**reduceops_common_args),
|
||||
)
|
||||
|
||||
add_docstr(
|
||||
torch.count_nonzero,
|
||||
r"""
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
# mypy: allow-untyped-defs
|
||||
import importlib
|
||||
import itertools
|
||||
import operator
|
||||
from typing import Any, List, Optional, Sequence, Tuple, TYPE_CHECKING, Union
|
||||
|
|
@ -29,7 +28,6 @@ __all__ = [
|
|||
"block_diag",
|
||||
"cdist",
|
||||
"chain_matmul",
|
||||
"cumsum",
|
||||
"einsum",
|
||||
"istft",
|
||||
"lu",
|
||||
|
|
@ -2037,62 +2035,6 @@ def chain_matmul(*matrices, out=None):
|
|||
return _VF.chain_matmul(matrices, out=out) # type: ignore[attr-defined]
|
||||
|
||||
|
||||
def cumsum(
|
||||
self: Tensor,
|
||||
dim: Optional[int] = None,
|
||||
*,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
out: Optional[Tensor] = None,
|
||||
axis: Optional[int] = None,
|
||||
):
|
||||
r"""
|
||||
cumsum(input, dim, *, dtype=None, out=None) -> Tensor
|
||||
|
||||
Returns the cumulative sum of elements of :attr:`input` in the dimension
|
||||
:attr:`dim`.
|
||||
|
||||
For example, if :attr:`input` is a vector of size N, the result will also be
|
||||
a vector of size N, with elements.
|
||||
|
||||
.. math::
|
||||
y_i = x_1 + x_2 + x_3 + \dots + x_i
|
||||
|
||||
Args:
|
||||
input (Tensor): the input tensor.
|
||||
dim (int): the dimension to do the operation over
|
||||
|
||||
Keyword args:
|
||||
dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor.
|
||||
If specified, the input tensor is cast to :attr:`dtype` before the operation
|
||||
is performed. This is useful for preventing data type overflows. Default: None.
|
||||
out (Tensor, optional): the output tensor.
|
||||
|
||||
Example::
|
||||
|
||||
>>> torch.manual_seed(0)
|
||||
>>> a = torch.randint(1, 20, (10,))
|
||||
>>> a
|
||||
tensor([16, 5, 1, 1, 12, 8, 6, 10, 10, 5])
|
||||
>>> torch.cumsum(a, dim=0)
|
||||
tensor([16, 21, 22, 23, 35, 43, 49, 59, 69, 74])
|
||||
"""
|
||||
if axis is not None:
|
||||
if dim is None:
|
||||
dim = axis
|
||||
else:
|
||||
raise RuntimeError("expected either 'dim' or 'axis' to be given, not both")
|
||||
if has_torch_function_unary(self):
|
||||
return handle_torch_function(cumsum, (self,), self, dim, dtype=dtype, out=out)
|
||||
if not torch.jit.is_scripting():
|
||||
if torch.are_deterministic_algorithms_enabled() and self.is_cuda:
|
||||
ref_func = importlib.import_module("torch._refs").cumsum
|
||||
return ref_func(self, dim, dtype=dtype, out=out)
|
||||
if out is None:
|
||||
return _VF.cumsum(self, dim, dtype=dtype) # type: ignore[attr-defined]
|
||||
else:
|
||||
return _VF.cumsum(self, dim, dtype=dtype, out=out) # type: ignore[attr-defined]
|
||||
|
||||
|
||||
def _lu_impl(A, pivot=True, get_infos=False, out=None):
|
||||
# type: (Tensor, bool, bool, Any) -> Tuple[Tensor, Tensor, Tensor]
|
||||
r"""Computes the LU factorization of a matrix or batches of matrices
|
||||
|
|
|
|||
|
|
@ -553,7 +553,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
|
|||
torch.cummax: lambda input, dim, out=None: -1,
|
||||
torch.cummin: lambda input, dim, out=None: -1,
|
||||
torch.cumprod: lambda input, dim, out=None, dtype=None: -1,
|
||||
torch.cumsum: lambda input, dim, out=None, dtype=None, axis=None: -1,
|
||||
torch.cumsum: lambda input, dim, out=None, dtype=None: -1,
|
||||
torch.cumulative_trapezoid: lambda y, x=None, dim=-1: -1,
|
||||
torch.logcumsumexp: lambda input, dim, out=None: -1,
|
||||
torch.deg2rad: lambda input, out=None: -1,
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user