mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
scatter_reduce documentation (#73125)
Summary: Reland of https://github.com/pytorch/pytorch/issues/68580 (which was milestoned for 1.11) plus partial revert of https://github.com/pytorch/pytorch/pull/72543 Pull Request resolved: https://github.com/pytorch/pytorch/pull/73125 Reviewed By: bdhirsh Differential Revision: D34355217 Pulled By: malfet fbshipit-source-id: 325ecdeaf53183d653b44ee5e6e8839ceefd9200
This commit is contained in:
parent
c5265e90c7
commit
71db31748a
|
|
@ -6056,7 +6056,7 @@
|
|||
- func: scatter_add.dimname(Tensor self, Dimname dim, Tensor index, Tensor src) -> Tensor
|
||||
variants: function, method
|
||||
|
||||
- func: _scatter_reduce.two(Tensor self, int dim, Tensor index, str reduce, *, int? output_size=None) -> Tensor
|
||||
- func: scatter_reduce.two(Tensor self, int dim, Tensor index, str reduce, *, int? output_size=None) -> Tensor
|
||||
variants: function, method
|
||||
dispatch:
|
||||
CPU: scatter_reduce_two_cpu
|
||||
|
|
|
|||
|
|
@ -593,6 +593,7 @@ Tensor class reference
|
|||
Tensor.scatter_
|
||||
Tensor.scatter_add_
|
||||
Tensor.scatter_add
|
||||
Tensor.scatter_reduce
|
||||
Tensor.select
|
||||
Tensor.select_scatter
|
||||
Tensor.set_
|
||||
|
|
|
|||
|
|
@ -118,6 +118,7 @@ Indexing, Slicing, Joining, Mutating Ops
|
|||
select_scatter
|
||||
slice_scatter
|
||||
scatter_add
|
||||
scatter_reduce
|
||||
split
|
||||
squeeze
|
||||
stack
|
||||
|
|
|
|||
|
|
@ -104,10 +104,10 @@ ALLOW_LIST = [
|
|||
("aten::nanquantile", datetime.date(2022, 9, 30)),
|
||||
("aten::_convolution_double_backward", datetime.date(2022, 3, 31)),
|
||||
("aten::_scatter_reduce", datetime.date(2022, 1, 31)),
|
||||
("aten::scatter_reduce.two", datetime.date(2022, 3, 15)),
|
||||
("aten::native_multi_head_self_attention", datetime.date(9999, 1, 1)),
|
||||
("aten::_native_multi_head_self_attention", datetime.date(9999, 1, 1)),
|
||||
("aten::_transform_bias_rescale_qkv", datetime.date(9999, 1, 1)),
|
||||
("aten::_scatter_reduce.two", datetime.date(9999, 1, 1)),
|
||||
]
|
||||
|
||||
ALLOW_LIST_COMPILED = [
|
||||
|
|
|
|||
|
|
@ -5773,7 +5773,7 @@ class TestTorch(TestCase):
|
|||
|
||||
for reduce in reduces:
|
||||
for dim in range(len(shape)):
|
||||
output = input._scatter_reduce(dim, index, reduce, output_size=output_size)
|
||||
output = input.scatter_reduce(dim, index, reduce, output_size=output_size)
|
||||
|
||||
# Check that output is of the correct size
|
||||
output_shape = copy.copy(shape)
|
||||
|
|
@ -5807,16 +5807,16 @@ class TestTorch(TestCase):
|
|||
self.assertTrue(torch.allclose(output, expected))
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "Expected `dim` to be in range -3 to 2"):
|
||||
torch._scatter_reduce(input, 4, index, "sum")
|
||||
torch.scatter_reduce(input, 4, index, "sum")
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "Shape mismatch"):
|
||||
index2 = torch.randint(0, output_size, (10, ), dtype=torch.long, device=device)
|
||||
torch._scatter_reduce(input, 0, index2, "sum")
|
||||
torch.scatter_reduce(input, 0, index2, "sum")
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "Expected `index` values to be in range 0 to 2"):
|
||||
input2 = torch.randn(10, dtype=dtype, device=device)
|
||||
index2 = torch.tensor([0, 1, 0, 1, 2, 3, 3, 4, 4, 3])
|
||||
torch._scatter_reduce(input2, 0, index2, "sum", output_size=2)
|
||||
torch.scatter_reduce(input2, 0, index2, "sum", output_size=2)
|
||||
|
||||
def test_structseq_repr(self):
|
||||
a = torch.arange(250).reshape(5, 5, 10)
|
||||
|
|
|
|||
|
|
@ -2595,6 +2595,6 @@
|
|||
- name: _efficientzerotensor(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
|
||||
output_differentiability: [False]
|
||||
|
||||
- name: _scatter_reduce.two(Tensor self, int dim, Tensor index, str reduce, *, int? output_size=None) -> Tensor
|
||||
- name: scatter_reduce.two(Tensor self, int dim, Tensor index, str reduce, *, int? output_size=None) -> Tensor
|
||||
self: scatter_reduce_backward(grad, self, dim, index, reduce, result)
|
||||
index: non_differentiable
|
||||
|
|
|
|||
|
|
@ -3374,6 +3374,12 @@ Example::
|
|||
|
||||
""".format(**reproducibility_notes))
|
||||
|
||||
add_docstr_all('scatter_reduce', r"""
|
||||
scatter_reduce(input, dim, index, reduce, *, output_size=None) -> Tensor
|
||||
|
||||
See :func:`torch.scatter_reduce`
|
||||
""")
|
||||
|
||||
add_docstr_all('select',
|
||||
r"""
|
||||
select(dim, index) -> Tensor
|
||||
|
|
|
|||
|
|
@ -8547,6 +8547,59 @@ scatter_add(input, dim, index, src) -> Tensor
|
|||
Out-of-place version of :meth:`torch.Tensor.scatter_add_`
|
||||
""")
|
||||
|
||||
add_docstr(torch.scatter_reduce, r"""
|
||||
scatter_reduce(input, dim, index, reduce, *, output_size=None) -> Tensor
|
||||
|
||||
Reduces all values from the :attr:`input` tensor to the indices specified in
|
||||
the :attr:`index` tensor. For each value in :attr:`input`, its output index is
|
||||
specified by its index in :attr:`input` for ``dimension != dim`` and by the
|
||||
corresponding value in :attr:`index` for ``dimension = dim``.
|
||||
The applied reduction for non-unique indices is defined via the :attr:`reduce`
|
||||
argument (:obj:`"sum"`, :obj:`"prod"`, :obj:`"mean"`, :obj:`"amax"`, :obj:`"amin"`).
|
||||
For non-existing indices, the output will be filled with the identity of the
|
||||
applied reduction (1 for :obj:`"prod"` and 0 otherwise).
|
||||
|
||||
It is also required that ``index.size(d) == input.size(d)`` for all dimensions ``d``.
|
||||
Moreover, if :attr:`output_size` is defined, the values of :attr:`index` must be
|
||||
between ``0`` and ``output_size - 1`` inclusive.
|
||||
|
||||
|
||||
For a 3-D tensor with :obj:`reduce="sum"`, the output is given as::
|
||||
|
||||
out[index[i][j][k]][j][k] += input[i][j][k] # if dim == 0
|
||||
out[i][index[i][j][k]][k] += input[i][j][k] # if dim == 1
|
||||
out[i][j][index[i][j][k]] += input[i][j][k] # if dim == 2
|
||||
|
||||
Note:
|
||||
This out-of-place operation is similar to the in-place versions of
|
||||
:meth:`~torch.Tensor.scatter_` and :meth:`~torch.Tensor.scatter_add_`,
|
||||
in which the output tensor is automatically created according to the
|
||||
maximum values in :attr:`index` and filled based on the identity of the
|
||||
applied reduction.
|
||||
|
||||
Note:
|
||||
{forward_reproducibility_note}
|
||||
|
||||
Args:
|
||||
input (Tensor): the input tensor
|
||||
dim (int): the axis along which to index
|
||||
index (LongTensor): the indices of elements to scatter and reduce.
|
||||
src (Tensor): the source elements to scatter and reduce
|
||||
reduce (str): the reduction operation to apply for non-unique indices
|
||||
(:obj:`"sum"`, :obj:`"prod"`, :obj:`"mean"`, :obj:`"amax"`, :obj:`"amin"`)
|
||||
output_size (int, optional): the size of the output at dimension :attr:`dim`.
|
||||
If set to :obj:`None`, will get automatically inferred according to
|
||||
:obj:`index.max() + 1`
|
||||
|
||||
Example::
|
||||
|
||||
>>> input = torch.tensor([1, 2, 3, 4, 5, 6])
|
||||
>>> index = torch.tensor([0, 1, 0, 1, 2, 1])
|
||||
>>> torch.scatter_reduce(input, 0, index, reduce="sum", output_size=3)
|
||||
tensor([4, 12, 5])
|
||||
|
||||
""".format(**reproducibility_notes))
|
||||
|
||||
add_docstr(torch.select,
|
||||
r"""
|
||||
select(input, dim, index) -> Tensor
|
||||
|
|
|
|||
|
|
@ -896,7 +896,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
|
|||
torch.saddmm: lambda input, mat1, mat2, beta=1, alpha=1, out=None: -1,
|
||||
torch.scatter: lambda input, dim, index, src: -1,
|
||||
torch.scatter_add: lambda input, dim, index, src: -1,
|
||||
torch._scatter_reduce: lambda input, dim, index, reduce, output_size=None: -1,
|
||||
torch.scatter_reduce: lambda input, dim, index, reduce, output_size=None: -1,
|
||||
torch.searchsorted: lambda sorted_sequence, input, out_int32=False, right=False, out=None: -1,
|
||||
torch.segment_reduce: lambda data, reduce="max", lengths=None, indices=None, axis=0, unsafe=False: -1,
|
||||
torch.select: lambda input, dim, index: -1,
|
||||
|
|
|
|||
|
|
@ -15537,7 +15537,7 @@ op_db: List[OpInfo] = [
|
|||
supports_fwgrad_bwgrad=True,
|
||||
),
|
||||
OpInfo(
|
||||
'_scatter_reduce',
|
||||
'scatter_reduce',
|
||||
dtypes=all_types_and(torch.float16, torch.bfloat16),
|
||||
sample_inputs_func=sample_inputs_scatter_reduce,
|
||||
supports_out=False,
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user