Revert "Fix unbind_copy and add its decomposition (#134319)"

This reverts commit 9f81270d75.

Reverted https://github.com/pytorch/pytorch/pull/134319 on behalf of https://github.com/clee2000 due to breaking some executorch tests D64568664 ([comment](https://github.com/pytorch/pytorch/pull/134319#issuecomment-2423157700))
commit 7b39fb5712 (parent cd1e9b0e60)
Author: PyTorch MergeBot
Date: 2024-10-18 20:09:40 +00:00
13 changed files with 32 additions and 110 deletions

View File

@@ -26,7 +26,6 @@
 #include <ATen/native/cpu/SerialStackImpl.h>
 #include <ATen/native/cpu/StackKernel.h>
 #include <ATen/quantized/QTensorImpl.h>
-#include <c10/core/GradMode.h>
 #include <c10/util/Exception.h>
 #include <optional>
 #include <c10/util/SmallVector.h>
@@ -4047,41 +4046,29 @@ void split_copy_Tensor_out(const at::Tensor & self, int64_t split_size, int64_t
   }
 }
-namespace {
-void copy_tensor_array_to_out(const char* name, const std::vector<Tensor>& array, at::TensorList out) {
-  TORCH_CHECK(out.size() == array.size(), name, " expected an out= argument of size ", array.size(), ", got size ", out.size());
-  for (const auto i : c10::irange(out.size())) {
-    if (resize_output_check(out[i], array[i].sizes())) {
-      out[i].resize_(array[i].sizes());
-    }
-    TORCH_CHECK(out[i].dtype() == array[i].dtype(),
-        "Expected out tensor to have dtype ", array[i].dtype(), ", but got ", out[i].dtype(), " instead");
-    TORCH_CHECK(out[i].device() == array[i].device(),
-        "Expected out tensor to have device ", array[i].device(), ", but got ", out[i].device(), " instead");
-    out[i].copy_(array[i]);
-  }
-}
-}
 void split_with_sizes_copy_out(const at::Tensor & self, at::IntArrayRef split_sizes, int64_t dim, at::TensorList out) {
   auto tmp = self.split_with_sizes(split_sizes, dim);
-  copy_tensor_array_to_out("split_with_sizes_copy_out()", tmp, out);
+  TORCH_CHECK(out.size() == tmp.size(), "split_with_sizes_copy_out() expected an out= argument of size ", tmp.size(), ", got size ", out.size());
+  for (const auto i : c10::irange(out.size())) {
+    if (resize_output_check(out[i], tmp[i].sizes())) {
+      out[i].resize_(tmp[i].sizes());
+    }
+    TORCH_CHECK(out[i].dtype() == tmp[i].dtype(),
+        "Expected out tensor to have dtype ", tmp[i].dtype(), ", but got ", out[i].dtype(), " instead");
+    TORCH_CHECK(out[i].device() == tmp[i].device(),
+        "Expected out tensor to have device ", tmp[i].device(), ", but got ", out[i].device(), " instead");
+    out[i].copy_(tmp[i]);
+  }
 }
 void unbind_copy_int_out(const at::Tensor & self, int64_t dim, at::TensorList out) {
-  if (at::GradMode::is_enabled()) {
-    for (const auto i : c10::irange(out.size())) {
-      TORCH_CHECK(!out[i].requires_grad(),
-          "unbind_copy(): functions with out=... arguments don't support automatic differentiation, "
-          "but one of the arguments requires grad."
-      );
-    }
-  }
   auto tmp = self.unbind(dim);
-  copy_tensor_array_to_out("unbind_copy_int_out()", tmp, out);
+  TORCH_CHECK(out.size() == tmp.size(), "unbind_copy_int_out() expected an out= argument of size ", tmp.size(), ", got size ", out.size());
+  for (const auto i : c10::irange(out.size())) {
+    out[i].copy_(tmp[i]);
+  }
 }
 int64_t sparse_dim_default(const Tensor& self) {
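For readers unfamiliar with the op being reverted: `unbind_copy` is the out-of-place counterpart of `unbind`, returning freshly allocated tensors instead of views, and its `int_out` overload copies each slice into caller-provided tensors, which is exactly what the restored loop above does (minus the reverted GradMode guard on `out=` tensors that require grad). A minimal sketch of the view-vs-copy contract:

```python
import torch

x = torch.arange(6.).reshape(3, 2)
views = x.unbind(0)               # tuple of views that alias x's storage
copies = torch.unbind_copy(x, 0)  # tuple of independent copies

x[0, 0] = 100.
print(views[0][0])   # tensor(100.) -- the view sees the in-place write
print(copies[0][0])  # tensor(0.)   -- the copy does not
```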

View File

@@ -449,7 +449,6 @@ dtensor_fails = {
     xfail("trapz"),
     xfail("triangular_solve"),
     xfail("unbind"),
-    xfail("unbind_copy"),
     xfail("unfold"),
     xfail("unfold_copy"),
     xfail("uniform"),

View File

@@ -504,7 +504,6 @@ aten::triu_indices.out
 aten::trunc
 aten::trunc.out
 aten::trunc_
-aten::unbind_copy.int_out
 aten::unfold
 aten::uniform
 aten::uniform.out

View File

@@ -1294,6 +1294,8 @@ aten::topk.values
 aten::transpose_
 aten::triangular_solve
 aten::triangular_solve.X
+aten::unbind_copy.int
+aten::unbind_copy.int_out
 aten::unique_consecutive
 aten::unique_consecutive.out
 aten::unique_dim

View File

@@ -1038,9 +1038,6 @@ class TestOperators(TestCase):
                 xfail("_native_batch_norm_legit"),
                 # TODO: implement batching rule
                 xfail("_batch_norm_with_update"),
-                xfail(
-                    "unbind_copy"
-                ),  # Batching rule not implemented for aten::unbind_copy.int.
             }
         ),
     )
@@ -1180,9 +1177,6 @@ class TestOperators(TestCase):
             xfail("sparse.mm", "reduce"),
             xfail("as_strided_scatter", ""),  # calls as_strided
             xfail("index_reduce", "prod"),  # .item() call
-            xfail(
-                "unbind_copy"
-            ),  # Batching rule not implemented for aten::unbind_copy.int.
             # ---------------------------------------------------------------------
         }
     )
@@ -1321,9 +1315,6 @@ class TestOperators(TestCase):
                 xfail("_native_batch_norm_legit"),
                 # TODO: implement batching rule
                 xfail("_batch_norm_with_update"),
-                xfail(
-                    "unbind_copy"
-                ),  # Batching rule not implemented for aten::unbind_copy.int.
                 # ----------------------------------------------------------------------
             }
@@ -1635,9 +1626,6 @@ class TestOperators(TestCase):
             xfail("__getitem__", ""),
             xfail("index_put", ""),
             xfail("view_as_complex"),
-            xfail(
-                "unbind_copy"
-            ),  # Batching rule not implemented for aten::unbind_copy.int.
             xfail("nn.functional.gaussian_nll_loss"),
             xfail("masked_select"),
             xfail(
@@ -1932,9 +1920,6 @@ class TestOperators(TestCase):
             xfail(
                 "as_strided_scatter"
             ),  # AssertionError: Tensor-likes are not close!
-            xfail(
-                "unbind_copy"
-            ),  # Batching rule not implemented for aten::unbind_copy.int.
             xfail("bernoulli"),  # calls random op
             xfail("bfloat16"),  # required rank 4 tensor to use channels_last format
             xfail("cdist"),  # Forward AD not implemented and no decomposition

View File

@@ -4375,9 +4375,6 @@ class TestVmapOperatorsOpInfo(TestCase):
                 xfail("torch.ops.aten._efficient_attention_forward"),  # outputs ints
                 # TypeError: expected Tensor as element 0 in argument 0, but got float
                 xfail("item"),
-                xfail(
-                    "unbind_copy"
-                ),  # Batching rule not implemented for aten::unbind_copy.int.
             }
         ),
     )
@@ -4453,9 +4450,6 @@ class TestVmapOperatorsOpInfo(TestCase):
             xfail("item"),
             xfail("tril"),  # Exception not raised on error input
             xfail("triu"),  # Exception not raised on error input
-            xfail(
-                "unbind_copy"
-            ),  # Batching rule not implemented for aten::unbind_copy.int.
             xfail("__getitem__", ""),
             xfail("count_nonzero"),
             xfail(

View File

@@ -349,7 +349,6 @@ def mps_ops_modifier(ops):
         'transpose_copy',
         'T',
         'unbind',
-        'unbind_copy',
         'unflatten',
         'unfold',
         'unfold_copy',

View File

@@ -240,7 +240,6 @@ GRADIENT_IMPLEMENTED_FOR_COMPLEX = {
     "slice",
     "constant_pad_nd",
     "unbind",
-    "unbind_copy",
     "split",
     "split_with_sizes",
     "unsafe_split",

View File

@@ -513,7 +513,6 @@ def _core_aten_decompositions_post_autograd() -> (
     aten.triu,
     aten.triu_,
     aten.unbind,
-    aten.unbind_copy.int,
     aten.unfold_backward,
     aten.unfold_copy,
     aten._unsafe_index,
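Dropping `aten.unbind_copy.int` here means the overload no longer resolves through the core ATen decomposition table. A quick check, assuming the public `core_aten_decompositions` helper from `torch._decomp` (exact behavior depends on the installed build):

```python
import torch
from torch._decomp import core_aten_decompositions

table = core_aten_decompositions()
print(torch.ops.aten.unbind.int in table)       # True: unbind still decomposes
print(torch.ops.aten.unbind_copy.int in table)  # False once this revert lands
```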

View File

@@ -129,8 +129,6 @@ class TorchRefsMode(torch.overrides.TorchFunctionMode):
             func = torch._decomp.decomposition_table.get(orig_func, None)
         elif func is None and isinstance(orig_func, torch._ops.OpOverloadPacket):
             default = getattr(orig_func, "default", None)
-            if default is None and orig_func._dir:
-                default = getattr(orig_func, orig_func._dir[0], None)
             if default is not None:
                 func = torch._decomp.decomposition_table.get(default, None)
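The two deleted lines were a fallback for overload packets that lack a `default` overload. `unbind_copy` is such a packet (it only has named overloads, per the expect files above), so after the revert the `getattr` yields `None` and the table lookup is simply skipped. A sketch of the structures involved:

```python
import torch

pkt = torch.ops.aten.unbind_copy      # an OpOverloadPacket
print(pkt.overloads())                # ['int', 'int_out'] -- no 'default'
print(getattr(pkt, "default", None))  # None, so the restored code does nothing

# The reverted fallback instead tried the first named overload:
#     default = getattr(pkt, pkt._dir[0], None)
```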

View File

@@ -2,16 +2,7 @@
 import inspect
 import warnings
 from functools import wraps
-from typing import (
-    Callable,
-    List,
-    NamedTuple,
-    Optional,
-    overload,
-    Sequence,
-    Tuple,
-    TypeVar,
-)
+from typing import Callable, NamedTuple, Optional, overload, Sequence, Tuple, TypeVar
 from typing_extensions import ParamSpec

 import torch
@@ -297,17 +288,11 @@ def out_wrapper(
         else:
             result = fn(*args, **kwargs)
         assert (
-            (isinstance(result, TensorLike) and is_tensor)
-            or (
-                isinstance(result, Tuple)  # type: ignore[arg-type]
-                and len(result) == len(out_names)  # type: ignore[arg-type]
-            )
-            or (
-                fn.__name__ == "unbind"
-                and isinstance(result, (List, Tuple))  # type: ignore[arg-type]
-            )
-        )
-        # unbind_copy is a special case: see https://github.com/pytorch/pytorch/issues/130829
+            isinstance(result, TensorLike)
+            and is_tensor
+            or isinstance(result, Tuple)  # type: ignore[arg-type]
+            and len(result) == len(out_names)  # type: ignore[arg-type]
+        )
         if out is not None:
             # Naively you might expect this assert to be true, but
             # it's not:
@@ -325,7 +310,7 @@ def out_wrapper(
             # the output tensor, but not the result--which will
             # be a normal meta tensor, but this is perfectly
             # harmless.
-            if is_tensor and fn.__name__ != "unbind":
+            if is_tensor:
                 assert isinstance(out, TensorLike)
                 # These two operations are done in-place
                 _maybe_resize_out(
@@ -333,10 +318,7 @@ def out_wrapper(
                 )
                 _safe_copy_out(copy_from=result, copy_to=out, exact_dtype=exact_dtype)  # type: ignore[arg-type]
             else:
-                if fn.__name__ != "unbind":
-                    assert isinstance(out, Tuple)  # type: ignore[arg-type]
-                else:
-                    assert isinstance(out, (List, Tuple))  # type: ignore[arg-type]
+                assert isinstance(out, Tuple)  # type: ignore[arg-type]
                 torch._check_type(
                     len(out) == len(result),  # type: ignore[arg-type]
                     lambda: f"expected tuple of {len(result)} elements but got {len(out)}",  # type: ignore[arg-type]

View File

@@ -304,7 +304,6 @@ __all__ = [
     "tensor_split",
     "transpose",
     "transpose_copy",
-    "unbind_copy",
     "unfold",
     "unfold_copy",
     "unsqueeze",
@@ -6381,7 +6380,6 @@ narrow_copy = _make_copy_from_view(aten.narrow)
 squeeze_copy = _make_copy_from_view(aten.squeeze)
 t_copy = _make_copy_from_view(aten.t)
 transpose_copy = _make_copy_from_view(aten.transpose)
-unbind_copy = _make_copy_from_view(aten.unbind)
 unsqueeze_copy = _make_copy_from_view(aten.unsqueeze)
 view_copy = _make_copy_from_view(aten.view)
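`_make_copy_from_view` manufactures each `_refs.*_copy` function from its view op, so deleting the `unbind_copy` line removes the ref definition to match its removal from `__all__` above. The contract the generated functions implement, shown with `transpose_copy`, which survives the revert:

```python
import torch

x = torch.randn(2, 3)
v = x.transpose(0, 1)              # view op: shares storage with x
c = torch.transpose_copy(x, 0, 1)  # generated _copy variant: owns its storage

print(v.data_ptr() == x.data_ptr())  # True
print(c.data_ptr() == x.data_ptr())  # False
```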

View File

@@ -19455,25 +19455,6 @@ op_db: List[OpInfo] = [
            supports_gradgrad=True,
            supports_out=False,
            ),
-    OpInfo('unbind_copy',
-           dtypes=all_types_and_complex_and(torch.complex32, torch.bool, torch.float16, torch.bfloat16),
-           ref=reference_unbind,
-           sample_inputs_func=sample_inputs_unbind,
-           error_inputs_func=error_inputs_unbind,
-           supports_forward_ad=True,
-           supports_fwgrad_bwgrad=True,
-           supports_gradgrad=True,
-           supports_out=True,
-           check_batched_grad=False,
-           skips=(
-               # Expected __torch_dispatch__ for aten::unbind_copy.int_out to return None
-               # but it returned something else instead.
-               DecorateInfo(
-                   unittest.expectedFailure,
-                   'TestProxyTensorOpInfo',
-                   'test_make_fx_symbolic_exhaustive_out'
-               ),
-           )),
     OpInfo('vstack',
            aliases=('row_stack',),
            dtypes=all_types_and_complex_and(torch.complex32, torch.bool, torch.float16, torch.bfloat16),
@@ -24070,6 +24051,10 @@ python_ref_db = [
     PythonRefInfo(
         "_refs.transpose_copy",
         torch_opinfo_name="transpose_copy",
+        skips=(
+            # RuntimeError: no _refs support for torch.Tensor.is_conj
+            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref'),
+        ),
         supports_out=True,
     ),
     PythonRefInfo(
@@ -24086,10 +24071,6 @@ python_ref_db = [
         torch_opinfo_name="T",
         error_inputs_func=partial(error_inputs_T, has_ndims_error=True),
     ),
-    PythonRefInfo(
-        "_refs.unbind_copy",
-        torch_opinfo_name="unbind_copy",
-    ),
     PythonRefInfo(
         "_refs.unfold",
         torch_opinfo_name="unfold",