diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 6821df7171f..06dfd3292ba 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -6056,7 +6056,7 @@
 - func: scatter_add.dimname(Tensor self, Dimname dim, Tensor index, Tensor src) -> Tensor
   variants: function, method

-- func: _scatter_reduce.two(Tensor self, int dim, Tensor index, str reduce, *, int? output_size=None) -> Tensor
+- func: scatter_reduce.two(Tensor self, int dim, Tensor index, str reduce, *, int? output_size=None) -> Tensor
   variants: function, method
   dispatch:
     CPU: scatter_reduce_two_cpu
diff --git a/docs/source/tensors.rst b/docs/source/tensors.rst
index 090824e0ee3..8f04298887b 100644
--- a/docs/source/tensors.rst
+++ b/docs/source/tensors.rst
@@ -593,6 +593,7 @@ Tensor class reference
    Tensor.scatter_
    Tensor.scatter_add_
    Tensor.scatter_add
+   Tensor.scatter_reduce
    Tensor.select
    Tensor.select_scatter
    Tensor.set_
diff --git a/docs/source/torch.rst b/docs/source/torch.rst
index d3ae7a7151e..e09675af82a 100644
--- a/docs/source/torch.rst
+++ b/docs/source/torch.rst
@@ -118,6 +118,7 @@ Indexing, Slicing, Joining, Mutating Ops
     select_scatter
     slice_scatter
     scatter_add
+    scatter_reduce
     split
     squeeze
     stack
diff --git a/test/forward_backward_compatibility/check_forward_backward_compatibility.py b/test/forward_backward_compatibility/check_forward_backward_compatibility.py
index 329e65b3281..ffe26acfe3b 100644
--- a/test/forward_backward_compatibility/check_forward_backward_compatibility.py
+++ b/test/forward_backward_compatibility/check_forward_backward_compatibility.py
@@ -104,10 +104,10 @@ ALLOW_LIST = [
     ("aten::nanquantile", datetime.date(2022, 9, 30)),
     ("aten::_convolution_double_backward", datetime.date(2022, 3, 31)),
     ("aten::_scatter_reduce", datetime.date(2022, 1, 31)),
-    ("aten::scatter_reduce.two", datetime.date(2022, 3, 15)),
     ("aten::native_multi_head_self_attention", datetime.date(9999, 1, 1)),
     ("aten::_native_multi_head_self_attention", datetime.date(9999, 1, 1)),
     ("aten::_transform_bias_rescale_qkv", datetime.date(9999, 1, 1)),
+    ("aten::_scatter_reduce.two", datetime.date(9999, 1, 1)),
 ]

 ALLOW_LIST_COMPILED = [
diff --git a/test/test_torch.py b/test/test_torch.py
index 16cf9e2e61f..ff8a4a1544a 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -5773,7 +5773,7 @@ class TestTorch(TestCase):

         for reduce in reduces:
             for dim in range(len(shape)):
-                output = input._scatter_reduce(dim, index, reduce, output_size=output_size)
+                output = input.scatter_reduce(dim, index, reduce, output_size=output_size)

                 # Check that output is of the correct size
                 output_shape = copy.copy(shape)
@@ -5807,16 +5807,16 @@
             self.assertTrue(torch.allclose(output, expected))

         with self.assertRaisesRegex(RuntimeError, "Expected `dim` to be in range -3 to 2"):
-            torch._scatter_reduce(input, 4, index, "sum")
+            torch.scatter_reduce(input, 4, index, "sum")

         with self.assertRaisesRegex(RuntimeError, "Shape mismatch"):
             index2 = torch.randint(0, output_size, (10, ), dtype=torch.long, device=device)
-            torch._scatter_reduce(input, 0, index2, "sum")
+            torch.scatter_reduce(input, 0, index2, "sum")

         with self.assertRaisesRegex(RuntimeError, "Expected `index` values to be in range 0 to 2"):
             input2 = torch.randn(10, dtype=dtype, device=device)
             index2 = torch.tensor([0, 1, 0, 1, 2, 3, 3, 4, 4, 3])
-            torch._scatter_reduce(input2, 0, index2, "sum", output_size=2)
+            torch.scatter_reduce(input2, 0, index2, "sum", output_size=2)

     def test_structseq_repr(self):
         a = torch.arange(250).reshape(5, 5, 10)
diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml
index 27e4007d569..7f7c13f01aa 100644
--- a/tools/autograd/derivatives.yaml
+++ b/tools/autograd/derivatives.yaml
@@ -2595,6 +2595,6 @@
 - name: _efficientzerotensor(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
   output_differentiability: [False]

-- name: _scatter_reduce.two(Tensor self, int dim, Tensor index, str reduce, *, int? output_size=None) -> Tensor
+- name: scatter_reduce.two(Tensor self, int dim, Tensor index, str reduce, *, int? output_size=None) -> Tensor
   self: scatter_reduce_backward(grad, self, dim, index, reduce, result)
   index: non_differentiable
diff --git a/torch/_tensor_docs.py b/torch/_tensor_docs.py
index 66ffffec87b..7ff5da2c2f4 100644
--- a/torch/_tensor_docs.py
+++ b/torch/_tensor_docs.py
@@ -3374,6 +3374,12 @@ Example::

 """.format(**reproducibility_notes))

+add_docstr_all('scatter_reduce', r"""
+scatter_reduce(dim, index, reduce, *, output_size=None) -> Tensor
+
+See :func:`torch.scatter_reduce`
+""")
+
 add_docstr_all('select', r"""
 select(dim, index) -> Tensor

diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py
index db65dd8cd98..4ba8d92b583 100644
--- a/torch/_torch_docs.py
+++ b/torch/_torch_docs.py
@@ -8547,6 +8547,57 @@ scatter_add(input, dim, index, src) -> Tensor

 Out-of-place version of :meth:`torch.Tensor.scatter_add_`
 """)

+add_docstr(torch.scatter_reduce, r"""
+scatter_reduce(input, dim, index, reduce, *, output_size=None) -> Tensor
+
+Reduces all values from the :attr:`input` tensor to the indices specified in
+the :attr:`index` tensor. For each value in :attr:`input`, its output index is
+specified by its index in :attr:`input` for ``dimension != dim`` and by the
+corresponding value in :attr:`index` for ``dimension = dim``.
+The applied reduction for non-unique indices is defined via the :attr:`reduce`
+argument (:obj:`"sum"`, :obj:`"prod"`, :obj:`"mean"`, :obj:`"amax"`, :obj:`"amin"`).
+For non-existing indices, the output will be filled with the identity of the
+applied reduction (1 for :obj:`"prod"` and 0 otherwise).
+
+It is also required that ``index.size(d) == input.size(d)`` for all dimensions ``d``.
+Moreover, if :attr:`output_size` is defined, the values of :attr:`index` must be
+between ``0`` and ``output_size - 1`` inclusive.
+
+For a 3-D tensor with :obj:`reduce="sum"`, the output is given as::
+
+    out[index[i][j][k]][j][k] += input[i][j][k]  # if dim == 0
+    out[i][index[i][j][k]][k] += input[i][j][k]  # if dim == 1
+    out[i][j][index[i][j][k]] += input[i][j][k]  # if dim == 2
+
+Note:
+    This out-of-place operation is similar to the in-place versions of
+    :meth:`~torch.Tensor.scatter_` and :meth:`~torch.Tensor.scatter_add_`,
+    in which the output tensor is automatically created according to the
+    maximum values in :attr:`index` and filled based on the identity of the
+    applied reduction.
+
+Note:
+    {forward_reproducibility_note}
+
+Args:
+    input (Tensor): the input tensor
+    dim (int): the axis along which to index
+    index (LongTensor): the indices of elements to scatter and reduce
+    reduce (str): the reduction operation to apply for non-unique indices
+        (:obj:`"sum"`, :obj:`"prod"`, :obj:`"mean"`, :obj:`"amax"`, :obj:`"amin"`)
+    output_size (int, optional): the size of the output at dimension :attr:`dim`.
+        If set to :obj:`None`, it is inferred automatically as
+        :obj:`index.max() + 1`
+
+Example::
+
+    >>> input = torch.tensor([1, 2, 3, 4, 5, 6])
+    >>> index = torch.tensor([0, 1, 0, 1, 2, 1])
+    >>> torch.scatter_reduce(input, 0, index, reduce="sum", output_size=3)
+    tensor([4, 12, 5])
+
+""".format(**reproducibility_notes))
+
 add_docstr(torch.select, r"""
 select(input, dim, index) -> Tensor
diff --git a/torch/overrides.py b/torch/overrides.py
index 3b0b437fa16..1bbc49f5dfd 100644
--- a/torch/overrides.py
+++ b/torch/overrides.py
@@ -896,7 +896,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
     torch.saddmm: lambda input, mat1, mat2, beta=1, alpha=1, out=None: -1,
     torch.scatter: lambda input, dim, index, src: -1,
     torch.scatter_add: lambda input, dim, index, src: -1,
-    torch._scatter_reduce: lambda input, dim, index, reduce, output_size=None: -1,
+    torch.scatter_reduce: lambda input, dim, index, reduce, output_size=None: -1,
     torch.searchsorted: lambda sorted_sequence, input, out_int32=False, right=False, out=None: -1,
     torch.segment_reduce: lambda data, reduce="max", lengths=None, indices=None, axis=0, unsafe=False: -1,
     torch.select: lambda input, dim, index: -1,
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index e1f3a1f5032..82d0f427541 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -15537,7 +15537,7 @@ op_db: List[OpInfo] = [
         supports_fwgrad_bwgrad=True,
     ),
     OpInfo(
-        '_scatter_reduce',
+        'scatter_reduce',
        dtypes=all_types_and(torch.float16, torch.bfloat16),
        sample_inputs_func=sample_inputs_scatter_reduce,
        supports_out=False,
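For reviewers who want to sanity-check the semantics the new docstring promises, here is a minimal runnable sketch. It assumes the signature introduced by this diff, scatter_reduce(input, dim, index, reduce, *, output_size=None); later PyTorch releases reworked scatter_reduce to take a src tensor and an include_self flag, so this is specific to the API revision shown above. The naive_scatter_reduce helper is a hypothetical reference written for this note, not code from the PR.

import torch

# Values from the docstring example added in this diff.
input = torch.tensor([1, 2, 3, 4, 5, 6])
index = torch.tensor([0, 1, 0, 1, 2, 1])

# index maps input positions {0, 2} -> slot 0, {1, 3, 5} -> slot 1,
# {4} -> slot 2, so "sum" yields [1+3, 2+4+6, 5] = [4, 12, 5].
out = torch.scatter_reduce(input, 0, index, reduce="sum", output_size=3)
print(out)  # tensor([ 4, 12,  5])


def naive_scatter_reduce(inp, idx, reduce, output_size=None):
    """Hypothetical 1-D reference mirroring the docstring, for clarity only."""
    if output_size is None:
        # The documented inference rule when output_size is omitted.
        output_size = int(idx.max()) + 1
    # Group input values by the destination slot index points them at.
    buckets = [[] for _ in range(output_size)]
    for value, slot in zip(inp.tolist(), idx.tolist()):
        buckets[slot].append(value)
    ops = {
        "sum": sum,
        "prod": lambda vs: torch.tensor(vs).prod().item(),
        "mean": lambda vs: sum(vs) / len(vs),
        "amax": max,
        "amin": min,
    }
    # Slots no index points at are filled with the reduction identity
    # the docstring specifies: 1 for "prod" and 0 otherwise.
    fill = 1 if reduce == "prod" else 0
    out = [ops[reduce](vs) if vs else fill for vs in buckets]
    return torch.tensor(out, dtype=inp.dtype)


# "mean" is skipped here to sidestep integer-division subtleties.
for reduce in ("sum", "prod", "amax", "amin"):
    expected = naive_scatter_reduce(input, index, reduce, output_size=3)
    actual = torch.scatter_reduce(input, 0, index, reduce=reduce, output_size=3)
    assert torch.equal(actual, expected), (reduce, actual, expected)

Since unmatched slots take the reduction identity, passing an output_size larger than index.max() + 1 simply pads the tail of the result with that identity value.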