Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-06 12:20:52 +01:00
Revert "Fix unbind_copy and add its decomposition (#134319)"
This reverts commit 9f81270d75.
Reverted https://github.com/pytorch/pytorch/pull/134319 on behalf of https://github.com/clee2000 due to breaking some executorch tests D64568664 ([comment](https://github.com/pytorch/pytorch/pull/134319#issuecomment-2423157700))
This commit is contained in:
parent
cd1e9b0e60
commit
7b39fb5712
@@ -26,7 +26,6 @@
 #include <ATen/native/cpu/SerialStackImpl.h>
 #include <ATen/native/cpu/StackKernel.h>
 #include <ATen/quantized/QTensorImpl.h>
-#include <c10/core/GradMode.h>
 #include <c10/util/Exception.h>
 #include <optional>
 #include <c10/util/SmallVector.h>
@@ -4047,41 +4046,29 @@ void split_copy_Tensor_out(const at::Tensor & self, int64_t split_size, int64_t
   }
 }
 
-namespace {
-
-void copy_tensor_array_to_out(const char* name, const std::vector<Tensor>& array, at::TensorList out) {
-  TORCH_CHECK(out.size() == array.size(), name, " expected an out= argument of size ", array.size(), ", got size ", out.size());
-  for (const auto i : c10::irange(out.size())) {
-    if (resize_output_check(out[i], array[i].sizes())) {
-      out[i].resize_(array[i].sizes());
-    }
-    TORCH_CHECK(out[i].dtype() == array[i].dtype(),
-        "Expected out tensor to have dtype ", array[i].dtype(), ", but got ", out[i].dtype(), " instead");
-    TORCH_CHECK(out[i].device() == array[i].device(),
-        "Expected out tensor to have device ", array[i].device(), ", but got ", out[i].device(), " instead");
-    out[i].copy_(array[i]);
-  }
-}
-
-}
-
 void split_with_sizes_copy_out(const at::Tensor & self, at::IntArrayRef split_sizes, int64_t dim, at::TensorList out) {
   auto tmp = self.split_with_sizes(split_sizes, dim);
-  copy_tensor_array_to_out("split_with_sizes_copy_out()", tmp, out);
+
+  TORCH_CHECK(out.size() == tmp.size(), "split_with_sizes_copy_out() expected an out= argument of size ", tmp.size(), ", got size ", out.size());
+  for (const auto i : c10::irange(out.size())) {
+    if (resize_output_check(out[i], tmp[i].sizes())) {
+      out[i].resize_(tmp[i].sizes());
+    }
+    TORCH_CHECK(out[i].dtype() == tmp[i].dtype(),
+        "Expected out tensor to have dtype ", tmp[i].dtype(), ", but got ", out[i].dtype(), " instead");
+    TORCH_CHECK(out[i].device() == tmp[i].device(),
+        "Expected out tensor to have device ", tmp[i].device(), ", but got ", out[i].device(), " instead");
+    out[i].copy_(tmp[i]);
+  }
 }
 
 void unbind_copy_int_out(const at::Tensor & self, int64_t dim, at::TensorList out) {
-  if (at::GradMode::is_enabled()) {
-    for (const auto i : c10::irange(out.size())) {
-      TORCH_CHECK(!out[i].requires_grad(),
-          "unbind_copy(): functions with out=... arguments don't support automatic differentiation, "
-          "but one of the arguments requires grad."
-      );
-    }
-  }
-
   auto tmp = self.unbind(dim);
-  copy_tensor_array_to_out("unbind_copy_int_out()", tmp, out);
+
+  TORCH_CHECK(out.size() == tmp.size(), "unbind_copy_int_out() expected an out= argument of size ", tmp.size(), ", got size ", out.size());
+  for (const auto i : c10::irange(out.size())) {
+    out[i].copy_(tmp[i]);
+  }
 }
 
 int64_t sparse_dim_default(const Tensor& self) {
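For context, unbind_copy is the copying variant of unbind: it slices self along dim and returns freshly allocated tensors rather than views, and the out= overload restored above simply copies each slice into the corresponding preallocated output. A minimal sketch of that behavior in Python, assuming a build that exposes torch.unbind_copy and accepts a sequence of tensors for out=:

import torch

x = torch.arange(6.).reshape(2, 3)

# Functional form: returns copies, not views of x.
a, b = torch.unbind_copy(x, dim=0)
assert a.data_ptr() != x.data_ptr() and torch.equal(a, x[0])

# out= form: mirrors the restored unbind_copy_int_out loop, copying each
# slice of x into the matching preallocated tensor.
out0, out1 = torch.empty(3), torch.empty(3)
torch.unbind_copy(x, dim=0, out=(out0, out1))
assert torch.equal(out1, x[1])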
@@ -449,7 +449,6 @@ dtensor_fails = {
     xfail("trapz"),
     xfail("triangular_solve"),
     xfail("unbind"),
-    xfail("unbind_copy"),
     xfail("unfold"),
     xfail("unfold_copy"),
     xfail("uniform"),
@@ -504,7 +504,6 @@ aten::triu_indices.out
 aten::trunc
 aten::trunc.out
 aten::trunc_
-aten::unbind_copy.int_out
 aten::unfold
 aten::uniform
 aten::uniform.out
@@ -1294,6 +1294,8 @@ aten::topk.values
 aten::transpose_
 aten::triangular_solve
 aten::triangular_solve.X
+aten::unbind_copy.int
+aten::unbind_copy.int_out
 aten::unique_consecutive
 aten::unique_consecutive.out
 aten::unique_dim
@@ -1038,9 +1038,6 @@ class TestOperators(TestCase):
                 xfail("_native_batch_norm_legit"),
                 # TODO: implement batching rule
                 xfail("_batch_norm_with_update"),
-                xfail(
-                    "unbind_copy"
-                ),  # Batching rule not implemented for aten::unbind_copy.int.
             }
         ),
     )
@@ -1180,9 +1177,6 @@ class TestOperators(TestCase):
             xfail("sparse.mm", "reduce"),
             xfail("as_strided_scatter", ""),  # calls as_strided
             xfail("index_reduce", "prod"),  # .item() call
-            xfail(
-                "unbind_copy"
-            ),  # Batching rule not implemented for aten::unbind_copy.int.
             # ---------------------------------------------------------------------
         }
     )
@@ -1321,9 +1315,6 @@ class TestOperators(TestCase):
             xfail("_native_batch_norm_legit"),
             # TODO: implement batching rule
             xfail("_batch_norm_with_update"),
-            xfail(
-                "unbind_copy"
-            ),  # Batching rule not implemented for aten::unbind_copy.int.
             # ----------------------------------------------------------------------
         }
@@ -1635,9 +1626,6 @@ class TestOperators(TestCase):
             xfail("__getitem__", ""),
             xfail("index_put", ""),
             xfail("view_as_complex"),
-            xfail(
-                "unbind_copy"
-            ),  # Batching rule not implemented for aten::unbind_copy.int.
             xfail("nn.functional.gaussian_nll_loss"),
             xfail("masked_select"),
             xfail(
@@ -1932,9 +1920,6 @@ class TestOperators(TestCase):
             xfail(
                 "as_strided_scatter"
             ),  # AssertionError: Tensor-likes are not close!
-            xfail(
-                "unbind_copy"
-            ),  # Batching rule not implemented for aten::unbind_copy.int.
             xfail("bernoulli"),  # calls random op
             xfail("bfloat16"),  # required rank 4 tensor to use channels_last format
             xfail("cdist"),  # Forward AD not implemented and no decomposition
@@ -4375,9 +4375,6 @@ class TestVmapOperatorsOpInfo(TestCase):
             xfail("torch.ops.aten._efficient_attention_forward"),  # outputs ints
             # TypeError: expected Tensor as element 0 in argument 0, but got float
             xfail("item"),
-            xfail(
-                "unbind_copy"
-            ),  # Batching rule not implemented for aten::unbind_copy.int.
         }
     ),
 )
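These xfail entries track vmap coverage: as the inline comment notes, there is no batching rule for aten::unbind_copy.int, so vmapping the op is expected to fail (or take a fallback path) until one is added. A rough illustration of the kind of call the tests exercise, assuming torch.func.vmap and torch.unbind_copy are available; the exact failure mode is an assumption:

import torch
from torch.func import vmap

x = torch.randn(4, 2, 3)

# Per-sample function that goes through aten::unbind_copy.int; without a
# batching rule, vmap either errors or falls back to a slow per-sample loop.
def first_slice(t):
    return torch.unbind_copy(t, dim=0)[0]

try:
    print(vmap(first_slice)(x).shape)
except RuntimeError as err:
    print("vmap could not batch unbind_copy:", err)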
@@ -4453,9 +4450,6 @@ class TestVmapOperatorsOpInfo(TestCase):
             xfail("item"),
             xfail("tril"),  # Exception not raised on error input
             xfail("triu"),  # Exception not raised on error input
-            xfail(
-                "unbind_copy"
-            ),  # Batching rule not implemented for aten::unbind_copy.int.
             xfail("__getitem__", ""),
             xfail("count_nonzero"),
             xfail(
@@ -349,7 +349,6 @@ def mps_ops_modifier(ops):
         'transpose_copy',
         'T',
         'unbind',
-        'unbind_copy',
         'unflatten',
         'unfold',
         'unfold_copy',
@@ -240,7 +240,6 @@ GRADIENT_IMPLEMENTED_FOR_COMPLEX = {
     "slice",
     "constant_pad_nd",
     "unbind",
-    "unbind_copy",
     "split",
     "split_with_sizes",
     "unsafe_split",
@@ -513,7 +513,6 @@ def _core_aten_decompositions_post_autograd() -> (
         aten.triu,
         aten.triu_,
         aten.unbind,
-        aten.unbind_copy.int,
         aten.unfold_backward,
         aten.unfold_copy,
         aten._unsafe_index,
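The list above feeds the core ATen decomposition table in torch._decomp; the revert drops aten.unbind_copy.int from it again. A small sketch of how one can check whether an overload currently has a core decomposition registered, assuming the public core_aten_decompositions() helper:

import torch
from torch._decomp import core_aten_decompositions

# Mapping from OpOverload to the Python decomposition that implements it.
table = core_aten_decompositions()
op = torch.ops.aten.unbind_copy.int
print("unbind_copy.int has a core decomposition:", op in table)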
@@ -129,8 +129,6 @@ class TorchRefsMode(torch.overrides.TorchFunctionMode):
             func = torch._decomp.decomposition_table.get(orig_func, None)
         elif func is None and isinstance(orig_func, torch._ops.OpOverloadPacket):
             default = getattr(orig_func, "default", None)
-            if default is None and orig_func._dir:
-                default = getattr(orig_func, orig_func._dir[0], None)
             if default is not None:
                 func = torch._decomp.decomposition_table.get(default, None)
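The two deleted lines were a fallback for overload packets that have no .default overload: the reverted PR used the packet's _dir to pick the first concrete overload (aten.unbind_copy only defines int and int_out). A small sketch of how overloads hang off a packet, assuming current torch.ops semantics:

import torch

packet = torch.ops.aten.unbind_copy        # an OpOverloadPacket
print(packet.overloads())                  # e.g. ['int', 'int_out'] -- no 'default'

# Resolving the first concrete OpOverload by name, roughly what the removed
# fallback did via packet._dir[0].
first = getattr(packet, packet.overloads()[0])
print(first)                               # aten::unbind_copy.int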
@@ -2,16 +2,7 @@
 import inspect
 import warnings
 from functools import wraps
-from typing import (
-    Callable,
-    List,
-    NamedTuple,
-    Optional,
-    overload,
-    Sequence,
-    Tuple,
-    TypeVar,
-)
+from typing import Callable, NamedTuple, Optional, overload, Sequence, Tuple, TypeVar
 from typing_extensions import ParamSpec
 
 import torch
@@ -297,17 +288,11 @@ def out_wrapper(
             else:
                 result = fn(*args, **kwargs)
             assert (
-                (isinstance(result, TensorLike) and is_tensor)
-                or (
-                    isinstance(result, Tuple)  # type: ignore[arg-type]
-                    and len(result) == len(out_names)  # type: ignore[arg-type]
-                )
-                or (
-                    fn.__name__ == "unbind"
-                    and isinstance(result, (List, Tuple))  # type: ignore[arg-type]
-                )
-            )
-            # unbind_copy is a special case: see https://github.com/pytorch/pytorch/issues/130829
+                isinstance(result, TensorLike)
+                and is_tensor
+                or isinstance(result, Tuple)  # type: ignore[arg-type]
+                and len(result) == len(out_names)  # type: ignore[arg-type]
+            )
             if out is not None:
                 # Naively you might expect this assert to be true, but
                 # it's not:
@@ -325,7 +310,7 @@ def out_wrapper(
                 # the output tensor, but not the result--which will
                 # be a normal meta tensor, but this is perfectly
                 # harmless.
-                if is_tensor and fn.__name__ != "unbind":
+                if is_tensor:
                     assert isinstance(out, TensorLike)
                     # These two operations are done in-place
                     _maybe_resize_out(
@@ -333,10 +318,7 @@ def out_wrapper(
                     )
                     _safe_copy_out(copy_from=result, copy_to=out, exact_dtype=exact_dtype)  # type: ignore[arg-type]
                 else:
-                    if fn.__name__ != "unbind":
-                        assert isinstance(out, Tuple)  # type: ignore[arg-type]
-                    else:
-                        assert isinstance(out, (List, Tuple))  # type: ignore[arg-type]
+                    assert isinstance(out, Tuple)  # type: ignore[arg-type]
                     torch._check_type(
                         len(out) == len(result),  # type: ignore[arg-type]
                         lambda: f"expected tuple of {len(result)} elements but got {len(out)}",  # type: ignore[arg-type]
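To make the restored tuple branch concrete: for multi-output refs, out_wrapper expects the wrapped function's result to be a Tuple whose length matches the out= argument, then resizes and copies element by element. A simplified standalone sketch of that contract, not the actual wrapper code:

from typing import Tuple
import torch

def copy_result_to_out(result: Tuple[torch.Tensor, ...], out: Tuple[torch.Tensor, ...]):
    # Simplified analogue of out_wrapper's tuple branch: enforce matching
    # lengths, then resize and copy each output in place.
    assert isinstance(result, tuple) and isinstance(out, tuple)
    torch._check(
        len(out) == len(result),
        lambda: f"expected tuple of {len(result)} elements but got {len(out)}",
    )
    for r, o in zip(result, out):
        o.resize_(r.shape)
        o.copy_(r)

a, b = torch.empty(0), torch.empty(0)
copy_result_to_out(torch.unbind(torch.arange(6.).reshape(2, 3), dim=0), (a, b))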
@@ -304,7 +304,6 @@ __all__ = [
     "tensor_split",
     "transpose",
     "transpose_copy",
-    "unbind_copy",
     "unfold",
     "unfold_copy",
     "unsqueeze",
@@ -6381,7 +6380,6 @@ narrow_copy = _make_copy_from_view(aten.narrow)
 squeeze_copy = _make_copy_from_view(aten.squeeze)
 t_copy = _make_copy_from_view(aten.t)
 transpose_copy = _make_copy_from_view(aten.transpose)
-unbind_copy = _make_copy_from_view(aten.unbind)
 unsqueeze_copy = _make_copy_from_view(aten.unsqueeze)
 view_copy = _make_copy_from_view(aten.view)
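The _refs copy variants are derived from their view counterparts; conceptually, the copy variant runs the view op and then materializes its output(s) as independent tensors (the real _make_copy_from_view also wires up out= handling). A rough sketch of the idea, not the actual helper:

import torch

def make_copy_from_view_sketch(view_fn):
    # Run the view op, then clone every output so nothing aliases the input.
    def copy_variant(*args, **kwargs):
        result = view_fn(*args, **kwargs)
        if isinstance(result, (tuple, list)):
            return tuple(r.clone() for r in result)
        return result.clone()
    return copy_variant

unbind_copy_sketch = make_copy_from_view_sketch(torch.unbind)
x = torch.arange(6.).reshape(2, 3)
pieces = unbind_copy_sketch(x, dim=0)
assert pieces[0].data_ptr() != x.data_ptr()  # copies, not views of x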
@@ -19455,25 +19455,6 @@ op_db: List[OpInfo] = [
            supports_gradgrad=True,
            supports_out=False,
            ),
-    OpInfo('unbind_copy',
-           dtypes=all_types_and_complex_and(torch.complex32, torch.bool, torch.float16, torch.bfloat16),
-           ref=reference_unbind,
-           sample_inputs_func=sample_inputs_unbind,
-           error_inputs_func=error_inputs_unbind,
-           supports_forward_ad=True,
-           supports_fwgrad_bwgrad=True,
-           supports_gradgrad=True,
-           supports_out=True,
-           check_batched_grad=False,
-           skips=(
-               # Expected __torch_dispatch__ for aten::unbind_copy.int_out to return None
-               # but it returned something else instead.
-               DecorateInfo(
-                   unittest.expectedFailure,
-                   'TestProxyTensorOpInfo',
-                   'test_make_fx_symbolic_exhaustive_out'
-               ),
-           )),
     OpInfo('vstack',
            aliases=('row_stack',),
            dtypes=all_types_and_complex_and(torch.complex32, torch.bool, torch.float16, torch.bfloat16),
@@ -24070,6 +24051,10 @@ python_ref_db = [
     PythonRefInfo(
         "_refs.transpose_copy",
         torch_opinfo_name="transpose_copy",
+        skips=(
+            # RuntimeError: no _refs support for torch.Tensor.is_conj
+            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref'),
+        ),
         supports_out=True,
     ),
     PythonRefInfo(
@@ -24086,10 +24071,6 @@ python_ref_db = [
         torch_opinfo_name="T",
         error_inputs_func=partial(error_inputs_T, has_ndims_error=True),
     ),
-    PythonRefInfo(
-        "_refs.unbind_copy",
-        torch_opinfo_name="unbind_copy",
-    ),
     PythonRefInfo(
         "_refs.unfold",
         torch_opinfo_name="unfold",