mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
Decompose/add reference for view_as_complex (#108005)
Aten source: d4a99631dd/aten/src/ATen/native/ComplexHelper.h (L78)
Documentation reference:
https://pytorch.org/docs/stable/generated/torch.view_as_complex.html
Note: this adds a new primitive `view_of_dtype`, which is trivially implemented, as its meta function is already implemented elsewhere.
Finally, this is not registered as a decomposition (yet), because TorchInductor does not yet support complex types. It should be added once we do.
Closes https://github.com/pytorch/pytorch/issues/108020 as well.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108005
Approved by: https://github.com/peterbell10, https://github.com/ezyang
This commit is contained in:
parent
366ce589d0
commit
c458fa0d35
|
|
@ -695,7 +695,7 @@ Tensor sparse_compressed_to_dense(
|
|||
|
||||
// Computes the strides for view_dtype output when the view dtype is
|
||||
// smaller than the original dtype
|
||||
inline DimVector compute_strides_for_view_dtype_downsize(IntArrayRef old_strides, int64_t size_ratio, ScalarType old_dtype, ScalarType new_dtype) {
|
||||
inline SymDimVector compute_strides_for_view_dtype_downsize(SymIntArrayRef old_strides, int64_t size_ratio, ScalarType old_dtype, ScalarType new_dtype) {
|
||||
const int64_t ndim = old_strides.size();
|
||||
|
||||
TORCH_CHECK(
|
||||
|
|
@ -703,7 +703,7 @@ inline DimVector compute_strides_for_view_dtype_downsize(IntArrayRef old_strides
|
|||
"self.stride(-1) must be 1 to view ", old_dtype, " as ", new_dtype,
|
||||
" (different element sizes), but got ", old_strides[ndim - 1]);
|
||||
|
||||
DimVector new_strides(ndim);
|
||||
SymDimVector new_strides(ndim);
|
||||
for (int64_t dim_idx = 0; dim_idx < ndim - 1; dim_idx++) {
|
||||
new_strides[dim_idx] = old_strides[dim_idx] * size_ratio;
|
||||
}
|
||||
|
|
@ -713,14 +713,14 @@ inline DimVector compute_strides_for_view_dtype_downsize(IntArrayRef old_strides
|
|||
|
||||
// Computes the strides for view_dtype output when the view dtype is
|
||||
// larger than the original dtype
|
||||
inline DimVector compute_strides_for_view_dtype_upsize(IntArrayRef old_strides, int64_t size_ratio, ScalarType old_dtype, ScalarType new_dtype) {
|
||||
inline SymDimVector compute_strides_for_view_dtype_upsize(SymIntArrayRef old_strides, int64_t size_ratio, ScalarType old_dtype, ScalarType new_dtype) {
|
||||
const int64_t ndim = old_strides.size();
|
||||
TORCH_CHECK(
|
||||
old_strides[ndim - 1] == 1,
|
||||
"self.stride(-1) must be 1 to view ", old_dtype, " as ", new_dtype,
|
||||
" (different element sizes), but got ", old_strides[ndim - 1]);
|
||||
|
||||
DimVector new_strides(ndim);
|
||||
SymDimVector new_strides(ndim);
|
||||
for (int64_t dim_idx = 0; dim_idx < ndim - 1; dim_idx++) {
|
||||
TORCH_CHECK(
|
||||
(old_strides[dim_idx] % size_ratio) == 0,
|
||||
|
|
@ -753,8 +753,7 @@ Tensor view_dtype(const Tensor& self, ScalarType dtype) {
|
|||
auto* impl = new_tensor.unsafeGetTensorImpl();
|
||||
|
||||
if (self_element_size == new_element_size) {
|
||||
impl->set_storage_offset(self.storage_offset());
|
||||
impl->set_sizes_and_strides(self.sizes(), self.strides());
|
||||
impl->set_sizes_and_strides(self.sym_sizes(), self.sym_strides(), self.sym_storage_offset());
|
||||
|
||||
} else if (self.dim() == 0) {
|
||||
TORCH_CHECK(false,
|
||||
|
|
@ -766,17 +765,16 @@ Tensor view_dtype(const Tensor& self, ScalarType dtype) {
|
|||
|
||||
int64_t size_ratio = self_element_size / new_element_size;
|
||||
auto new_strides = compute_strides_for_view_dtype_downsize(
|
||||
self.strides(), size_ratio, self.scalar_type(), dtype);
|
||||
self.sym_strides(), size_ratio, self.scalar_type(), dtype);
|
||||
|
||||
auto old_sizes = self.sizes();
|
||||
DimVector new_sizes(self.dim());
|
||||
auto old_sizes = self.sym_sizes();
|
||||
SymDimVector new_sizes(self.dim());
|
||||
std::copy(old_sizes.begin(), old_sizes.end(), new_sizes.begin());
|
||||
new_sizes[self.dim() - 1] *= size_ratio;
|
||||
|
||||
auto new_storage_offset = size_ratio * self.storage_offset();
|
||||
auto new_storage_offset = size_ratio * self.sym_storage_offset();
|
||||
|
||||
impl->set_storage_offset(new_storage_offset);
|
||||
impl->set_sizes_and_strides(new_sizes, new_strides);
|
||||
impl->set_sizes_and_strides(new_sizes, new_strides, new_storage_offset);
|
||||
|
||||
} else {
|
||||
// Upsizing element size
|
||||
|
|
@ -784,29 +782,28 @@ Tensor view_dtype(const Tensor& self, ScalarType dtype) {
|
|||
int64_t size_ratio = new_element_size / self_element_size;
|
||||
|
||||
TORCH_CHECK(
|
||||
(self.size(-1) % size_ratio) == 0,
|
||||
(self.sym_size(-1) % size_ratio) == 0,
|
||||
"self.size(-1) must be divisible by ", size_ratio, " to view ",
|
||||
self.scalar_type(), " as ", dtype, " (different element sizes), ",
|
||||
"but got ", self.size(-1));
|
||||
"but got ", self.sym_size(-1));
|
||||
|
||||
TORCH_CHECK(
|
||||
(self.storage_offset() % size_ratio) == 0,
|
||||
(self.sym_storage_offset() % size_ratio) == 0,
|
||||
"self.storage_offset() must be divisible by ", size_ratio, " to view ",
|
||||
self.scalar_type(), " as ", dtype, " (different element sizes), but got ",
|
||||
self.storage_offset());
|
||||
self.sym_storage_offset());
|
||||
|
||||
auto new_strides = compute_strides_for_view_dtype_upsize(
|
||||
self.strides(), size_ratio, self.scalar_type(), dtype);
|
||||
self.sym_strides(), size_ratio, self.scalar_type(), dtype);
|
||||
|
||||
auto old_sizes = self.sizes();
|
||||
DimVector new_sizes(self.dim());
|
||||
auto old_sizes = self.sym_sizes();
|
||||
SymDimVector new_sizes(self.dim());
|
||||
std::copy(old_sizes.begin(), old_sizes.end(), new_sizes.begin());
|
||||
new_sizes[self.dim() - 1] /= size_ratio;
|
||||
|
||||
auto new_storage_offset = self.storage_offset() / size_ratio;
|
||||
auto new_storage_offset = self.sym_storage_offset() / size_ratio;
|
||||
|
||||
impl->set_storage_offset(new_storage_offset);
|
||||
impl->set_sizes_and_strides(new_sizes, new_strides);
|
||||
impl->set_sizes_and_strides(new_sizes, new_strides, new_storage_offset);
|
||||
}
|
||||
|
||||
return new_tensor;
|
||||
|
|
|
|||
|
|
@ -1876,6 +1876,7 @@ class TestRefsOpsInfo(TestCase):
|
|||
'_refs.imag',
|
||||
'_refs.reshape_as',
|
||||
'_refs.view_as',
|
||||
'_refs.view_as_complex' # TorchInductor does not support complex at the moment.
|
||||
}
|
||||
|
||||
@parametrize("op", ref_ops_names)
|
||||
|
|
|
|||
|
|
@ -373,6 +373,7 @@ def core_aten_decompositions() -> Dict[torch._ops.OperatorBase, Callable]:
|
|||
aten._unsafe_index,
|
||||
aten.upsample_bilinear2d,
|
||||
aten.upsample_nearest2d_backward,
|
||||
aten.view_as_complex,
|
||||
aten.xlogy,
|
||||
aten.xlogy_,
|
||||
aten.zero,
|
||||
|
|
|
|||
|
|
@ -1886,6 +1886,7 @@ make_fallback(aten.topk)
|
|||
make_fallback(aten.upsample_bicubic2d_backward, require_contiguous)
|
||||
make_fallback(aten._scaled_mm.default)
|
||||
|
||||
# TODO: This is done, just need to enable support in TorchInductor for complex types.
|
||||
make_fallback(aten.view_as_complex, require_contiguous)
|
||||
|
||||
# The following were added as a result of https://github.com/pytorch/pytorch/pull/94039 to pass tests
|
||||
|
|
|
|||
|
|
@ -2642,11 +2642,6 @@ def meta_complex(real, imag):
|
|||
return real.new_empty(out_shape, dtype=corresponding_complex_dtype(real.dtype))
|
||||
|
||||
|
||||
@register_meta(aten.view.dtype)
|
||||
def view_dtype(self, dtype):
|
||||
return utils.clone_preserve_strides(self).to(dtype)
|
||||
|
||||
|
||||
@register_meta([aten.nonzero_static.default, aten.nonzero_static.out])
|
||||
def nonzero_static(self, *, size: int, fill_value: int = -1):
|
||||
return self.new_empty((size, self.dim()), dtype=torch.long)
|
||||
|
|
|
|||
|
|
@ -148,6 +148,7 @@ __all__ = [
|
|||
"squeeze",
|
||||
"transpose",
|
||||
"view_of",
|
||||
"view_element_type",
|
||||
#
|
||||
# Functionalized view mutations
|
||||
#
|
||||
|
|
@ -172,7 +173,6 @@ __all__ = [
|
|||
"item",
|
||||
"maximum_value",
|
||||
"minimum_value",
|
||||
"to_dtype",
|
||||
"copy_strided",
|
||||
#
|
||||
# Inplace prims
|
||||
|
|
@ -1780,6 +1780,27 @@ view_of = _make_prim(
|
|||
doc=_view_of_doc,
|
||||
)
|
||||
|
||||
|
||||
def _view_element_type_meta(a: TensorLikeType, dtype: torch.dtype) -> TensorLikeType:
|
||||
return a.view(dtype)
|
||||
|
||||
|
||||
def _view_element_type_aten(a: Tensor, dtype: torch.dtype) -> Tensor:
|
||||
return a.view(dtype)
|
||||
|
||||
|
||||
_view_element_type_doc = """
|
||||
Creates a view of the tensor with a different dtype.
|
||||
"""
|
||||
|
||||
view_element_type = _make_prim(
|
||||
schema="view_of_dtype(Tensor(a) a, ScalarType dtype) -> Tensor",
|
||||
meta=_view_element_type_meta,
|
||||
impl_aten=_view_element_type_aten,
|
||||
return_type=RETURN_TYPE.VIEW,
|
||||
doc=_view_element_type_doc,
|
||||
)
|
||||
|
||||
#
|
||||
# Functionalized view mutations
|
||||
#
|
||||
|
|
|
|||
|
|
@ -289,6 +289,7 @@ __all__ = [
|
|||
"view_as",
|
||||
"vsplit",
|
||||
"vstack",
|
||||
"view_as_complex",
|
||||
"unflatten",
|
||||
"unbind",
|
||||
"triu",
|
||||
|
|
@ -949,6 +950,43 @@ def trunc(a):
|
|||
return prims.trunc(a)
|
||||
|
||||
|
||||
# TODO: register this as a real ref/decomposition once TorchInductor supports complex!
|
||||
def view_as_complex(self: TensorLikeType) -> TensorLikeType:
|
||||
input_dtype = self.dtype
|
||||
torch._check(
|
||||
utils.is_float_dtype(input_dtype),
|
||||
lambda: f"view_as_complex is only supported for floating point"
|
||||
f"tensors, but got a tensor of scalar type: {input_dtype}",
|
||||
)
|
||||
sizes = self.size()
|
||||
torch._check(
|
||||
len(sizes) != 0,
|
||||
lambda: "Input tensor must have one or more dimensions",
|
||||
)
|
||||
torch._check(
|
||||
sizes[-1] == 2,
|
||||
lambda: "Tensor must have a last dimension of size 2",
|
||||
)
|
||||
|
||||
old_strides = self.stride()
|
||||
torch._check(
|
||||
old_strides[-1] == 1,
|
||||
lambda: "Tensor must have a last dimension with stride 1",
|
||||
)
|
||||
dims = old_strides[:-1]
|
||||
torch._check(
|
||||
py_all(stride % 2 == 0 for stride in dims),
|
||||
lambda: "Tensor must have a stride divisible by 2 for all but last dimension",
|
||||
)
|
||||
torch._check(
|
||||
self.storage_offset() % 2 == 0,
|
||||
lambda: "Tensor must have a storage_offset divisible by 2",
|
||||
)
|
||||
return prims.view_element_type(
|
||||
self, utils.corresponding_complex_dtype(input_dtype)
|
||||
).squeeze(-1)
|
||||
|
||||
|
||||
def _make_elementwise_binary_reference(
|
||||
type_promotion_kind,
|
||||
aten_op=infer_aten_op,
|
||||
|
|
|
|||
|
|
@ -21576,6 +21576,10 @@ python_ref_db = [
|
|||
),
|
||||
],
|
||||
),
|
||||
PythonRefInfo(
|
||||
"_refs.view_as_complex",
|
||||
torch_opinfo_name="view_as_complex",
|
||||
),
|
||||
]
|
||||
python_ref_db += opinfo.definitions.python_ref_db
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user