scatter_reduce documentation (#73125)

Summary:
Reland of https://github.com/pytorch/pytorch/issues/68580 (which was milestoned for 1.11) plus partial revert of https://github.com/pytorch/pytorch/pull/72543

Pull Request resolved: https://github.com/pytorch/pytorch/pull/73125

Reviewed By: bdhirsh

Differential Revision: D34355217

Pulled By: malfet

fbshipit-source-id: 325ecdeaf53183d653b44ee5e6e8839ceefd9200
(cherry picked from commit 71db31748a)
This commit is contained in:
Nikita Shulga 2022-02-22 11:19:49 -08:00 committed by PyTorch MergeBot
parent e12c57a35b
commit cfb6c942fe
10 changed files with 70 additions and 9 deletions

View File

@@ -6056,7 +6056,7 @@
- func: scatter_add.dimname(Tensor self, Dimname dim, Tensor index, Tensor src) -> Tensor - func: scatter_add.dimname(Tensor self, Dimname dim, Tensor index, Tensor src) -> Tensor
variants: function, method variants: function, method
- func: _scatter_reduce.two(Tensor self, int dim, Tensor index, str reduce, *, int? output_size=None) -> Tensor - func: scatter_reduce.two(Tensor self, int dim, Tensor index, str reduce, *, int? output_size=None) -> Tensor
variants: function, method variants: function, method
dispatch: dispatch:
CPU: scatter_reduce_two_cpu CPU: scatter_reduce_two_cpu

View File

@@ -593,6 +593,7 @@ Tensor class reference
Tensor.scatter_ Tensor.scatter_
Tensor.scatter_add_ Tensor.scatter_add_
Tensor.scatter_add Tensor.scatter_add
Tensor.scatter_reduce
Tensor.select Tensor.select
Tensor.select_scatter Tensor.select_scatter
Tensor.set_ Tensor.set_

View File

@@ -118,6 +118,7 @@ Indexing, Slicing, Joining, Mutating Ops
select_scatter select_scatter
slice_scatter slice_scatter
scatter_add scatter_add
scatter_reduce
split split
squeeze squeeze
stack stack

View File

@@ -104,10 +104,10 @@ ALLOW_LIST = [
("aten::nanquantile", datetime.date(2022, 9, 30)), ("aten::nanquantile", datetime.date(2022, 9, 30)),
("aten::_convolution_double_backward", datetime.date(2022, 3, 31)), ("aten::_convolution_double_backward", datetime.date(2022, 3, 31)),
("aten::_scatter_reduce", datetime.date(2022, 1, 31)), ("aten::_scatter_reduce", datetime.date(2022, 1, 31)),
("aten::scatter_reduce.two", datetime.date(2022, 3, 15)),
("aten::native_multi_head_self_attention", datetime.date(9999, 1, 1)), ("aten::native_multi_head_self_attention", datetime.date(9999, 1, 1)),
("aten::_native_multi_head_self_attention", datetime.date(9999, 1, 1)), ("aten::_native_multi_head_self_attention", datetime.date(9999, 1, 1)),
("aten::_transform_bias_rescale_qkv", datetime.date(9999, 1, 1)), ("aten::_transform_bias_rescale_qkv", datetime.date(9999, 1, 1)),
("aten::_scatter_reduce.two", datetime.date(9999, 1, 1)),
] ]
ALLOW_LIST_COMPILED = [ ALLOW_LIST_COMPILED = [

View File

@@ -5773,7 +5773,7 @@ class TestTorch(TestCase):
for reduce in reduces: for reduce in reduces:
for dim in range(len(shape)): for dim in range(len(shape)):
output = input._scatter_reduce(dim, index, reduce, output_size=output_size) output = input.scatter_reduce(dim, index, reduce, output_size=output_size)
# Check that output is of the correct size # Check that output is of the correct size
output_shape = copy.copy(shape) output_shape = copy.copy(shape)
@@ -5807,16 +5807,16 @@
self.assertTrue(torch.allclose(output, expected)) self.assertTrue(torch.allclose(output, expected))
with self.assertRaisesRegex(RuntimeError, "Expected `dim` to be in range -3 to 2"): with self.assertRaisesRegex(RuntimeError, "Expected `dim` to be in range -3 to 2"):
torch._scatter_reduce(input, 4, index, "sum") torch.scatter_reduce(input, 4, index, "sum")
with self.assertRaisesRegex(RuntimeError, "Shape mismatch"): with self.assertRaisesRegex(RuntimeError, "Shape mismatch"):
index2 = torch.randint(0, output_size, (10, ), dtype=torch.long, device=device) index2 = torch.randint(0, output_size, (10, ), dtype=torch.long, device=device)
torch._scatter_reduce(input, 0, index2, "sum") torch.scatter_reduce(input, 0, index2, "sum")
with self.assertRaisesRegex(RuntimeError, "Expected `index` values to be in range 0 to 2"): with self.assertRaisesRegex(RuntimeError, "Expected `index` values to be in range 0 to 2"):
input2 = torch.randn(10, dtype=dtype, device=device) input2 = torch.randn(10, dtype=dtype, device=device)
index2 = torch.tensor([0, 1, 0, 1, 2, 3, 3, 4, 4, 3]) index2 = torch.tensor([0, 1, 0, 1, 2, 3, 3, 4, 4, 3])
torch._scatter_reduce(input2, 0, index2, "sum", output_size=2) torch.scatter_reduce(input2, 0, index2, "sum", output_size=2)
def test_structseq_repr(self): def test_structseq_repr(self):
a = torch.arange(250).reshape(5, 5, 10) a = torch.arange(250).reshape(5, 5, 10)

View File

@@ -2595,6 +2595,6 @@
- name: _efficientzerotensor(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor - name: _efficientzerotensor(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
output_differentiability: [False] output_differentiability: [False]
- name: _scatter_reduce.two(Tensor self, int dim, Tensor index, str reduce, *, int? output_size=None) -> Tensor - name: scatter_reduce.two(Tensor self, int dim, Tensor index, str reduce, *, int? output_size=None) -> Tensor
self: scatter_reduce_backward(grad, self, dim, index, reduce, result) self: scatter_reduce_backward(grad, self, dim, index, reduce, result)
index: non_differentiable index: non_differentiable

View File

@@ -3374,6 +3374,12 @@ Example::
""".format(**reproducibility_notes)) """.format(**reproducibility_notes))
add_docstr_all('scatter_reduce', r"""
scatter_reduce(input, dim, index, reduce, *, output_size=None) -> Tensor
See :func:`torch.scatter_reduce`
""")
add_docstr_all('select', add_docstr_all('select',
r""" r"""
select(dim, index) -> Tensor select(dim, index) -> Tensor

View File

@@ -8547,6 +8547,59 @@ scatter_add(input, dim, index, src) -> Tensor
Out-of-place version of :meth:`torch.Tensor.scatter_add_` Out-of-place version of :meth:`torch.Tensor.scatter_add_`
""") """)
add_docstr(torch.scatter_reduce, r"""
scatter_reduce(input, dim, index, reduce, *, output_size=None) -> Tensor
Reduces all values from the :attr:`input` tensor to the indices specified in
the :attr:`index` tensor. For each value in :attr:`input`, its output index is
specified by its index in :attr:`input` for ``dimension != dim`` and by the
corresponding value in :attr:`index` for ``dimension = dim``.
The applied reduction for non-unique indices is defined via the :attr:`reduce`
argument (:obj:`"sum"`, :obj:`"prod"`, :obj:`"mean"`, :obj:`"amax"`, :obj:`"amin"`).
For non-existing indices, the output will be filled with the identity of the
applied reduction (1 for :obj:`"prod"` and 0 otherwise).
It is also required that ``index.size(d) == input.size(d)`` for all dimensions ``d``.
Moreover, if :attr:`output_size` is defined, the values of :attr:`index` must be
between ``0`` and ``output_size - 1`` inclusive.
For a 3-D tensor with :obj:`reduce="sum"`, the output is given as::
out[index[i][j][k]][j][k] += input[i][j][k] # if dim == 0
out[i][index[i][j][k]][k] += input[i][j][k] # if dim == 1
out[i][j][index[i][j][k]] += input[i][j][k] # if dim == 2
Note:
This out-of-place operation is similar to the in-place versions of
:meth:`~torch.Tensor.scatter_` and :meth:`~torch.Tensor.scatter_add_`,
in which the output tensor is automatically created according to the
maximum values in :attr:`index` and filled based on the identity of the
applied reduction.
Note:
{forward_reproducibility_note}
Args:
input (Tensor): the input tensor
dim (int): the axis along which to index
index (LongTensor): the indices of elements to scatter and reduce.
src (Tensor): the source elements to scatter and reduce
reduce (str): the reduction operation to apply for non-unique indices
(:obj:`"sum"`, :obj:`"prod"`, :obj:`"mean"`, :obj:`"amax"`, :obj:`"amin"`)
output_size (int, optional): the size of the output at dimension :attr:`dim`.
If set to :obj:`None`, will get automatically inferred according to
:obj:`index.max() + 1`
Example::
>>> input = torch.tensor([1, 2, 3, 4, 5, 6])
>>> index = torch.tensor([0, 1, 0, 1, 2, 1])
>>> torch.scatter_reduce(input, 0, index, reduce="sum", output_size=3)
tensor([4, 12, 5])
""".format(**reproducibility_notes))
add_docstr(torch.select, add_docstr(torch.select,
r""" r"""
select(input, dim, index) -> Tensor select(input, dim, index) -> Tensor

View File

@@ -896,7 +896,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
torch.saddmm: lambda input, mat1, mat2, beta=1, alpha=1, out=None: -1, torch.saddmm: lambda input, mat1, mat2, beta=1, alpha=1, out=None: -1,
torch.scatter: lambda input, dim, index, src: -1, torch.scatter: lambda input, dim, index, src: -1,
torch.scatter_add: lambda input, dim, index, src: -1, torch.scatter_add: lambda input, dim, index, src: -1,
torch._scatter_reduce: lambda input, dim, index, reduce, output_size=None: -1, torch.scatter_reduce: lambda input, dim, index, reduce, output_size=None: -1,
torch.searchsorted: lambda sorted_sequence, input, out_int32=False, right=False, out=None: -1, torch.searchsorted: lambda sorted_sequence, input, out_int32=False, right=False, out=None: -1,
torch.segment_reduce: lambda data, reduce="max", lengths=None, indices=None, axis=0, unsafe=False: -1, torch.segment_reduce: lambda data, reduce="max", lengths=None, indices=None, axis=0, unsafe=False: -1,
torch.select: lambda input, dim, index: -1, torch.select: lambda input, dim, index: -1,

View File

@@ -15537,7 +15537,7 @@ op_db: List[OpInfo] = [
supports_fwgrad_bwgrad=True, supports_fwgrad_bwgrad=True,
), ),
OpInfo( OpInfo(
'_scatter_reduce', 'scatter_reduce',
dtypes=all_types_and(torch.float16, torch.bfloat16), dtypes=all_types_and(torch.float16, torch.bfloat16),
sample_inputs_func=sample_inputs_scatter_reduce, sample_inputs_func=sample_inputs_scatter_reduce,
supports_out=False, supports_out=False,