add slice/select/diagonal_scatter variants as primitive ops (#64430)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64430

The functionalization pass needs `{view}_scatter` versions of the slice/select/diagonal ops in order to correctly propagate mutations from a view to its base. On top of that, the implementations need to be primitive w.r.t. autograd, because they look something like `...slice().copy_()`, and the functionalization pass can't use views + mutations inside of its own alias-removal machinery!

I added some basic tests that I tried to base off of existing tests for views (particularly around testing the derivative formulas), but I'm wondering if I should add something more comprehensive.

Also, as_strided fits into this category - the functionalization pass will need an `as_strided_scatter` op that's primitive w.r.t. autograd. I didn't add it for now, because it'll involve duplicating a bunch of logic from the current `as_strided_backward()` function, and also writing a derivative formula that I wasn't sure how to write :)

Test Plan: Imported from OSS

Reviewed By: albanD

Differential Revision: D31942092

Pulled By: bdhirsh

fbshipit-source-id: c702a57c2748a7c771c14e4bcc3e996b48fcc4c8
This commit is contained in:
parent 665c148e42
commit 03f3a0331b
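The core idea in the summary can be illustrated with a small sketch (assuming a PyTorch build that already ships these ops): a mutation applied through a view is re-expressed as a single functional op that returns a fresh tensor, which is exactly the rewrite the functionalization pass needs to perform.

    import torch

    # Mutating through a view, as user code normally does:
    x = torch.zeros(4, 4)
    x[1, :] = torch.ones(4)          # i.e. x.select(0, 1).copy_(torch.ones(4))

    # The same result expressed without any view or in-place op, which is the
    # form the functionalization pass can work with:
    y = torch.zeros(4, 4)
    y = torch.select_scatter(y, torch.ones(4), 0, 1)

    assert torch.equal(x, y)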
@@ -2561,4 +2561,27 @@ std::vector<Tensor> unflatten_dense_tensors(const Tensor& flat, TensorList tenso
  return outputs;
}

}} // at::native

at::Tensor slice_scatter(const at::Tensor& self, const at::Tensor& src, int64_t dim, c10::optional<int64_t> start, c10::optional<int64_t> end, int64_t step) {
  auto output = self.clone();
  auto slice = output.slice(dim, start, end, step);
  TORCH_CHECK(slice.sizes() == src.sizes(), "expected src to have a size equal to the slice of self. src size = ", src.sizes(), ", slice size = ", slice.sizes());
  slice.copy_(src);
  return output;
}

at::Tensor select_scatter(const at::Tensor& self, const at::Tensor& src, int64_t dim, int64_t index) {
  auto output = self.clone();
  auto slice = output.select(dim, index);
  TORCH_CHECK(slice.sizes() == src.sizes(), "expected src to have a size equal to the slice of self. src size = ", src.sizes(), ", slice size = ", slice.sizes());
  slice.copy_(src);
  return output;
}

at::Tensor diagonal_scatter(const at::Tensor& self, const at::Tensor& src, int64_t offset, int64_t dim1, int64_t dim2) {
  auto output = self.clone();
  auto slice = output.diagonal(offset, dim1, dim2);
  TORCH_CHECK(slice.sizes() == src.sizes(), "expected src to have a size equal to the slice of self. src size = ", src.sizes(), ", slice size = ", slice.sizes());
  slice.copy_(src);
  return output;
}

} // namespace native
} // namespace at
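For reference, the three kernels above all follow the same clone, view, shape-check, copy_ pattern. A rough Python equivalent of slice_scatter (illustrative only; `slice_scatter_ref` below is not part of this diff) could look like:

    import torch

    def slice_scatter_ref(inp, src, dim=0, start=None, end=None, step=1):
        # clone the input, view the requested slice, and copy src into it
        out = inp.clone()
        index = [slice(None)] * out.dim()
        index[dim] = slice(start, end, step)
        view = out[tuple(index)]
        assert view.shape == src.shape, "src must match the shape of the slice"
        view.copy_(src)
        return out

    a = torch.zeros(4, 4)
    b = torch.ones(2, 4)
    assert torch.equal(slice_scatter_ref(a, b, dim=0, start=0, end=2),
                       torch.slice_scatter(a, b, 0, 0, 2))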
@@ -3910,6 +3910,27 @@
  dispatch:
    CompositeExplicitAutograd: slice_backward

- func: slice_scatter(Tensor self, Tensor src, int dim=0, int? start=None, int? end=None, int step=1) -> Tensor
  variants: function, method
  device_check: NoCheck
  device_guard: False
  dispatch:
    CompositeExplicitAutograd: slice_scatter

- func: select_scatter(Tensor self, Tensor src, int dim, int index) -> Tensor
  variants: function, method
  device_check: NoCheck
  device_guard: False
  dispatch:
    CompositeExplicitAutograd: select_scatter

- func: diagonal_scatter(Tensor self, Tensor src, int offset=0, int dim1=0, int dim2=1) -> Tensor
  variants: function, method
  device_check: NoCheck
  device_guard: False
  dispatch:
    CompositeExplicitAutograd: diagonal_scatter

- func: slogdet(Tensor self) -> (Tensor sign, Tensor logabsdet)
  variants: function, method
  dispatch:
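Because the schemas declare `variants: function, method`, each new op is callable both as a torch.* function and as a Tensor method, with the defaults shown above. A quick sketch, assuming the ops are available in the build:

    import torch

    x = torch.zeros(3, 4)

    out1 = torch.slice_scatter(x, torch.ones(1, 4), dim=0, start=0, end=1)  # function variant
    out2 = x.slice_scatter(torch.ones(3, 4))                                # method variant, all defaults
    out3 = x.diagonal_scatter(torch.full((3,), 7.0))                        # offset=0, dim1=0, dim2=1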
@@ -1266,6 +1266,20 @@

- name: slice_backward(Tensor grad_output, int[] input_sizes, int dim, int start, int end, int step) -> Tensor
  grad_output: grad.slice(dim, start, end, step)

- name: slice_scatter(Tensor self, Tensor src, int dim=0, int? start=None, int? end=None, int step=1) -> Tensor
  self: slice_scatter(grad, zeros_like(src), dim, start, end, step)
  src: grad.slice(dim, start, end, step)
  result: auto_linear

- name: select_scatter(Tensor self, Tensor src, int dim, int index) -> Tensor
  self: select_scatter(grad, zeros_like(src), dim, index)
  src: grad.select(dim, index)
  result: auto_linear

- name: diagonal_scatter(Tensor self, Tensor src, int offset=0, int dim1=0, int dim2=1) -> Tensor
  self: diagonal_scatter(grad, zeros_like(src), offset, dim1, dim2)
  src: grad.diagonal(offset, dim1, dim2)
  result: auto_linear

- name: slogdet(Tensor self) -> (Tensor sign, Tensor logabsdet)
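The formulas above make the backward of each scatter op the mirror image of its forward: the gradient w.r.t. `self` is the incoming gradient with the scattered region zeroed out, and the gradient w.r.t. `src` is the corresponding view of the incoming gradient. A small numerical check of that claim for diagonal_scatter (a sketch, not part of the PR's tests):

    import torch
    from torch.autograd import gradcheck

    a = torch.randn(4, 4, dtype=torch.double, requires_grad=True)
    src = torch.randn(4, dtype=torch.double, requires_grad=True)

    # gradcheck compares analytic gradients against finite differences
    assert gradcheck(torch.diagonal_scatter, (a, src))

    # The analytic formulas from derivatives.yaml, written out by hand:
    out = torch.diagonal_scatter(a, src)
    grad_out = torch.randn_like(out)
    grad_a, grad_src = torch.autograd.grad(out, (a, src), grad_out)
    assert torch.allclose(grad_a, torch.diagonal_scatter(grad_out, torch.zeros_like(src)))
    assert torch.allclose(grad_src, grad_out.diagonal())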
@@ -1169,6 +1169,13 @@ diagonal(offset=0, dim1=0, dim2=1) -> Tensor
See :func:`torch.diagonal`
""")

add_docstr_all('diagonal_scatter',
               r"""
diagonal_scatter(src, offset=0, dim1=0, dim2=1) -> Tensor

See :func:`torch.diagonal_scatter`
""")

add_docstr_all('fill_diagonal_',
               r"""
fill_diagonal_(fill_value, wrap=False) -> Tensor
@@ -3352,18 +3359,21 @@ add_docstr_all('select',
               r"""
select(dim, index) -> Tensor

Slices the :attr:`self` tensor along the selected dimension at the given index.
This function returns a view of the original tensor with the given dimension removed.
See :func:`torch.select`
""")

Args:
    dim (int): the dimension to slice
    index (int): the index to select with
add_docstr_all('select_scatter',
               r"""
select_scatter(src, dim, index) -> Tensor

.. note::
See :func:`torch.select_scatter`
""")

    :meth:`select` is equivalent to slicing. For example,
    ``tensor.select(0, index)`` is equivalent to ``tensor[index]`` and
    ``tensor.select(2, index)`` is equivalent to ``tensor[:,:,index]``.
add_docstr_all('slice_scatter',
               r"""
slice_scatter(src, dim=0, start=None, end=None, step=1) -> Tensor

See :func:`torch.slice_scatter`
""")

add_docstr_all('set_',
@@ -3156,6 +3156,58 @@ Examples::
    [ 1.0500,  0.7336, -0.3836, -1.1015]]])
""".format(**common_args))

add_docstr(torch.diagonal_scatter,
           r"""
diagonal_scatter(input, src, offset=0, dim1=0, dim2=1) -> Tensor

Embeds the values of the :attr:`src` tensor into :attr:`input` along
the diagonal elements of :attr:`input`, with respect to :attr:`dim1`
and :attr:`dim2`.

This function returns a tensor with fresh storage; it does not
return a view.

The argument :attr:`offset` controls which diagonal to consider:

- If :attr:`offset` = 0, it is the main diagonal.
- If :attr:`offset` > 0, it is above the main diagonal.
- If :attr:`offset` < 0, it is below the main diagonal.

Args:
    {input} Must be at least 2-dimensional.
    src (Tensor): the tensor to embed into :attr:`input`.
    offset (int, optional): which diagonal to consider. Default: 0
        (main diagonal).
    dim1 (int, optional): first dimension with respect to which to
        take diagonal. Default: 0.
    dim2 (int, optional): second dimension with respect to which to
        take diagonal. Default: 1.

.. note::

    :attr:`src` must be of the proper size in order to be embedded
    into :attr:`input`. Specifically, it should have the same shape as
    ``torch.diagonal(input, offset, dim1, dim2)``

Examples::

    >>> a = torch.zeros(3, 3)
    >>> a
    tensor([[0., 0., 0.],
            [0., 0., 0.],
            [0., 0., 0.]])

    >>> torch.diagonal_scatter(a, torch.ones(3), 0)
    tensor([[1., 0., 0.],
            [0., 1., 0.],
            [0., 0., 1.]])

    >>> torch.diagonal_scatter(a, torch.ones(2), 1)
    tensor([[0., 1., 0.],
            [0., 0., 1.],
            [0., 0., 0.]])
""".format(**common_args))

add_docstr(torch.diff, r"""
diff(input, n=1, dim=-1, prepend=None, append=None) -> Tensor
@@ -8354,6 +8406,97 @@ scatter_add(input, dim, index, src) -> Tensor
Out-of-place version of :meth:`torch.Tensor.scatter_add_`
""")

add_docstr(torch.select,
           r"""
select(input, dim, index) -> Tensor

Slices the :attr:`input` tensor along the selected dimension at the given index.
This function returns a view of the original tensor with the given dimension removed.

Args:
    {input} (Tensor)
    dim (int): the dimension to slice
    index (int): the index to select with

.. note::

    :meth:`select` is equivalent to slicing. For example,
    ``tensor.select(0, index)`` is equivalent to ``tensor[index]`` and
    ``tensor.select(2, index)`` is equivalent to ``tensor[:,:,index]``.
""")

add_docstr(torch.select_scatter,
           r"""
select_scatter(input, src, dim, index) -> Tensor

Embeds the values of the :attr:`src` tensor into :attr:`input` at the given index.
This function returns a tensor with fresh storage; it does not create a view.

Args:
    {input} (Tensor)
    src (Tensor): The tensor to embed into :attr:`input`
    dim (int): the dimension to insert the slice into.
    index (int): the index to select with

.. note::

    :attr:`src` must be of the proper size in order to be embedded
    into :attr:`input`. Specifically, it should have the same shape as
    ``torch.select(input, dim, index)``

Example::

    >>> a = torch.zeros(2, 2)
    >>> b = torch.ones(2)
    >>> a.select_scatter(b, 0, 0)
    tensor([[1., 1.],
            [0., 0.]])
""")

add_docstr(torch.slice_scatter,
           r"""
slice_scatter(input, src, dim=0, start=None, end=None, step=1) -> Tensor

Embeds the values of the :attr:`src` tensor into :attr:`input` at the given
dimension.
This function returns a tensor with fresh storage; it does not create a view.

Args:
    {input} (Tensor)
    src (Tensor): The tensor to embed into :attr:`input`
    dim (int): the dimension to insert the slice into
    start (Optional[int]): the start index of where to insert the slice
    end (Optional[int]): the end index of where to insert the slice
    step (int): how many elements to skip in between

Example::

    >>> a = torch.zeros(8, 8)
    >>> b = torch.ones(2, 8)
    >>> a.slice_scatter(b, start=6)
    tensor([[0., 0., 0., 0., 0., 0., 0., 0.],
            [0., 0., 0., 0., 0., 0., 0., 0.],
            [0., 0., 0., 0., 0., 0., 0., 0.],
            [0., 0., 0., 0., 0., 0., 0., 0.],
            [0., 0., 0., 0., 0., 0., 0., 0.],
            [0., 0., 0., 0., 0., 0., 0., 0.],
            [1., 1., 1., 1., 1., 1., 1., 1.],
            [1., 1., 1., 1., 1., 1., 1., 1.]])

    >>> b = torch.ones(8, 2)
    >>> a.slice_scatter(b, dim=1, start=2, end=6, step=2)
    tensor([[0., 0., 1., 0., 1., 0., 0., 0.],
            [0., 0., 1., 0., 1., 0., 0., 0.],
            [0., 0., 1., 0., 1., 0., 0., 0.],
            [0., 0., 1., 0., 1., 0., 0., 0.],
            [0., 0., 1., 0., 1., 0., 0., 0.],
            [0., 0., 1., 0., 1., 0., 0., 0.],
            [0., 0., 1., 0., 1., 0., 0., 0.],
            [0., 0., 1., 0., 1., 0., 0., 0.]])
""")

add_docstr(torch.set_flush_denormal,
           r"""
set_flush_denormal(mode) -> bool
@@ -429,6 +429,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
        torch.diagflat: lambda input, offset=0: -1,
        torch.diff: lambda input, n=1, dim=-1, prepend=None, append=None, out=None: -1,
        torch.diagonal: lambda input, offset=0, dim1=0, dim2=1: -1,
        torch.diagonal_scatter: lambda input, src, offset=0, dim1=0, dim2=1: -1,
        torch.digamma: lambda input, out=None: -1,
        torch.dist: lambda input, other, p=2: -1,
        torch.div: lambda input, other, rounding_mode=None, out=None: -1,

@@ -885,6 +886,8 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
        torch.searchsorted: lambda sorted_sequence, input, out_int32=False, right=False, out=None: -1,
        torch.segment_reduce: lambda data, reduce="max", lengths=None, indices=None, axis=0, unsafe=False: -1,
        torch.select: lambda input, dim, index: -1,
        torch.select_scatter: lambda input, src, dim, index: -1,
        torch.slice_scatter: lambda input, src, dim, start, end, step: -1,
        torch.selu: lambda input, inplace=False: -1,
        torch.sigmoid: lambda input, out=None: -1,
        torch.sign: lambda input, out=None: -1,

@@ -1102,6 +1105,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
        Tensor.xpu: lambda self, memory_format=torch.preserve_format: -1,
        Tensor.data_ptr: lambda self: -1,
        Tensor.dense_dim: lambda self: -1,
        Tensor.diagonal_scatter: lambda self, src, offset=0, dim1=0, dim2=1: -1,
        Tensor.dim: lambda self: -1,
        Tensor.double: lambda self, memory_format=torch.preserve_format: -1,
        Tensor.cdouble: lambda self, memory_format=torch.preserve_format: -1,

@@ -1154,9 +1158,11 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
        Tensor.resize_as: lambda self, other: -1,
        Tensor.retain_grad: lambda self: -1,
        Tensor.set_: lambda self, source=None, storage_offset=0, size=None, stride=None: -1,
        Tensor.select_scatter: lambda self, src, dim, index: -1,
        Tensor.share_memory_: lambda self: -1,
        Tensor.short: lambda self, memory_format=torch.preserve_format: -1,
        Tensor.size: lambda self: -1,
        Tensor.slice_scatter: lambda self, src, dim, start, end, step: -1,
        Tensor.sparse_dim: lambda self: -1,
        Tensor.sparse_mask: lambda self, mask: -1,
        Tensor.sparse_resize_: lambda self, size1, size2, dense_dim: -1,
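These entries register the new functions and methods with the `__torch_function__` override machinery, so tensor-like subclasses can intercept them like any other overridable op. A minimal sketch of what that enables on a recent PyTorch (the `LoggingTensor` subclass below is purely illustrative):

    import torch

    class LoggingTensor(torch.Tensor):
        seen = []  # records the names of overridable ops that were called

        @classmethod
        def __torch_function__(cls, func, types, args=(), kwargs=None):
            cls.seen.append(func.__name__)
            return super().__torch_function__(func, types, args, kwargs or {})

    t = torch.zeros(3, 3).as_subclass(LoggingTensor)
    torch.diagonal_scatter(t, torch.ones(3))
    assert "diagonal_scatter" in LoggingTensor.seen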
@@ -4927,6 +4927,34 @@ def sample_inputs_diagonal_diag_embed(op_info, device, dtype, requires_grad, **k
    return list(generator())


def sample_inputs_diagonal_scatter(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)

    # Shapes for 2D Tensors
    shapes_2d = ((M, M), (3, 5), (5, 3))

    # Shapes for 3D Tensors
    shapes_3d = ((M, M, M),)

    args_2d = ((), (2,), (-2,), (1,))
    args_3d = ((1, 1, 2), (2, 0, 1), (-2, 0, 1))

    def generator():
        for input_shape, arg in chain(product(shapes_2d, args_2d), product(shapes_3d, args_3d)):
            input_ = make_arg(input_shape)
            # We can programmatically figure out the right shape for src:
            # It should be the same size as input.diagonal(other_args...)
            if not isinstance(arg, tuple):
                arg_tuple = (arg,)
            else:
                arg_tuple = arg
            src_shape = input_.diagonal(*arg_tuple).size()
            src = make_arg(src_shape)
            yield SampleInput(input_, args=(src, *arg_tuple))

    return list(generator())


def sample_inputs_to_sparse(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
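The helper derives `src`'s shape from `input.diagonal(*args)` so that every generated sample satisfies the size check in the kernel. The same relationship is easy to spot-check by hand (an illustrative check, not part of the test suite):

    import torch

    for shape, args in [((5, 5), (2,)), ((5, 5), (-2,)), ((4, 4, 4), (1, 1, 2))]:
        inp = torch.randn(*shape)
        src = torch.randn(inp.diagonal(*args).shape)
        out = torch.diagonal_scatter(inp, src, *args)
        # the scattered values land exactly on the selected diagonal
        assert torch.equal(out.diagonal(*args), src)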
@@ -5766,6 +5794,48 @@ def sample_inputs_select(op_info, device, dtype, requires_grad, **kwargs):
    return list(generator())


def sample_inputs_select_scatter(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)

    cases = (((S, S, S), (S, S), (1, 2)),
             ((S, S, S), (S, S), (-1, 2)),
             ((S, S, S), (S, S), (-1, -1)),
             ((S, S, S), (S, S), (1, -1)),
             ((S,), (), (0, 2))
             )

    def generator():
        for input_shape, src_shape, args in cases:
            input_ = make_arg(input_shape)
            src = make_arg(src_shape)
            yield SampleInput(input_, args=(src, *args))

    return list(generator())


def sample_inputs_slice_scatter(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)

    cases = (((L, L, L), (L, L, L,), (0, 0, L, 1)),
             ((L, L, L), (L // 2, L, L,), (0, L // 2, L, 1)),
             ((L, L, L), (L // 4, L, L,), (0, L // 2, L, 2)),
             ((L, L, L), (L, L, L,), (1, 0, L, 1)),
             ((L, L, L), (L, L // 2, L,), (1, L // 2, L, 1)),
             ((L, L, L), (L, L // 4, L,), (1, L // 2, L, 2)),
             ((L, L, L), (L, L, L,), (2, 0, L, 1)),
             ((L, L, L), (L, L, L // 2,), (2, L // 2, L, 1)),
             ((L, L, L), (L, L, L // 4,), (2, L // 2, L, 2)),
             )

    def generator():
        for input_shape, src_shape, args in cases:
            input_ = make_arg(input_shape)
            src = make_arg(src_shape)
            yield SampleInput(input_, args=(src, *args))

    return list(generator())


def sample_inputs_rbinops(op_info, device, dtype, requires_grad, supports_dtype_kwargs=True, **kwargs):
    def _make_tensor_helper(shape, low=None, high=None):
        return make_tensor(shape, device, dtype, low=low, high=high, requires_grad=requires_grad)
@@ -7397,6 +7467,11 @@ op_db: List[OpInfo] = [
           supports_out=False,
           supports_forward_ad=True,
           sample_inputs_func=sample_inputs_diagonal_diag_embed),
    OpInfo('diagonal_scatter',
           dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16),
           supports_out=False,
           supports_forward_ad=True,
           sample_inputs_func=sample_inputs_diagonal_scatter),
    OpInfo('eq',
           dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16),
           supports_autograd=False,

@@ -9278,6 +9353,16 @@ op_db: List[OpInfo] = [
           assert_jit_shape_analysis=True,
           supports_forward_ad=True,
           supports_out=False),
    OpInfo('select_scatter',
           dtypes=all_types_and(torch.bfloat16, torch.half, torch.bool),
           sample_inputs_func=sample_inputs_select_scatter,
           supports_forward_ad=True,
           supports_out=False),
    OpInfo('slice_scatter',
           dtypes=all_types_and(torch.bfloat16, torch.half, torch.bool),
           sample_inputs_func=sample_inputs_slice_scatter,
           supports_forward_ad=True,
           supports_out=False),
    UnaryUfuncInfo('signbit',
                   ref=np.signbit,
                   dtypes=all_types_and(torch.bool, torch.bfloat16, torch.half),