[primTorch] Fix off by 1 in canonicalize_dim (#83198)
Also fix an issue in the `unsqueeze` ref due to this change.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/83198
Approved by: https://github.com/ngimel
This commit is contained in:
parent 6a5ca409da
commit 4010f96121
@@ -7,10 +7,13 @@ int64_t maybe_wrap_dim_slow(
     int64_t dim,
     int64_t dim_post_expr,
     bool wrap_scalar) {
-  if (dim_post_expr <= 0) {
+  TORCH_CHECK_INDEX(
+      dim_post_expr >= 0, "Rank cannot be negative but got ", dim_post_expr);
+
+  if (dim_post_expr == 0) {
     TORCH_CHECK_INDEX(
         wrap_scalar,
-        "dimension specified as ",
+        "Dimension specified as ",
         dim,
         " but tensor has no dimensions");
     return c10::maybe_wrap_dim(dim, /*dim_post_expr=*/1, /*wrap_scalar=*/false);
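Illustration (not part of the diff): the wrap_scalar distinction is visible from Python. Ops that wrap scalar dims accept dim 0 on a 0-d tensor, while ops that do not hit the TORCH_CHECK_INDEX above. A minimal sketch, assuming a build that includes this change:

import torch

t = torch.tensor(3.0)   # 0-d tensor, so dim_post_expr == 0 in the C++ above
# sum wraps scalar dims, so dim=0 is accepted for a 0-d tensor.
print(t.sum(dim=0))     # tensor(3.)
# unbind does not wrap scalars, so it should raise the capitalized message.
try:
    t.unbind(0)
except IndexError as e:
    print(e)            # expected: "Dimension specified as 0 but tensor has no dimensions"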
@@ -5888,7 +5888,7 @@ class TestTorch(TestCase):
             torch.tensor([1]).unflatten(0, [])
         with self.assertRaisesRegex(RuntimeError, r"Provided sizes \[2, 2\] don't multiply up to the size of dim 0 \(1\)"):
             torch.tensor([1]).unflatten(0, [2, 2])
-        with self.assertRaisesRegex(IndexError, r"dimension specified as 0 but tensor has no dimensions"):
+        with self.assertRaisesRegex(IndexError, r"Dimension specified as 0 but tensor has no dimensions"):
             torch.tensor(1).unflatten(0, [0])
         with self.assertRaisesRegex(RuntimeError, r"only one dimension can be inferred"):
             torch.randn(5, 10).unflatten(1, (-1, -1))
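For contrast with the error cases exercised above, a small illustrative sketch of valid unflatten calls (not part of the diff):

import torch

t = torch.arange(6)
# Split dim 0 of size 6 into (2, 3); at most one size may be -1 and is inferred.
print(t.unflatten(0, (2, 3)).shape)   # torch.Size([2, 3])
print(t.unflatten(0, (-1, 3)).shape)  # torch.Size([2, 3])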
@@ -1384,12 +1384,18 @@ conj = _make_prim(
 )


-def expand_dims(a: TensorLikeType, dimensions: DimsSequenceType) -> TensorLikeType:
+def expand_dims(
+    a: TensorLikeType, dimensions: DimsSequenceType, ndim=None
+) -> TensorLikeType:
     """
     Creates a view of a with a.ndim + len(dimensions) dimensions, with new
     dimensions of length one at the dimensions specified by dimensions.
     """
-    dims = sorted(utils.canonicalize_dims(a.ndim, dimensions))  # type: ignore[arg-type]
+    if ndim is not None:
+        # TODO: this is only here to support the unsqueeze ref
+        dims = sorted(utils.canonicalize_dims(ndim, dimensions))  # type: ignore[arg-type]
+    else:
+        dims = sorted(utils.canonicalize_dims(a.ndim, dimensions))  # type: ignore[arg-type]
     if len(set(dims)) != len(dims):
         msg = "Received duplicate dimensions to expand in {0}".format(str(dimensions))
         raise ValueError(msg)
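Illustrative sketch of the semantics expand_dims provides, written against the public unsqueeze API rather than the private prim (expand_dims_sketch is a hypothetical helper, not part of the change):

import torch

def expand_dims_sketch(a, dimensions):
    # Insert a length-1 dimension at each requested (already canonical) position.
    for d in sorted(dimensions):
        a = a.unsqueeze(d)
    return a

a = torch.randn(3, 4)
print(expand_dims_sketch(a, (0, 2)).shape)  # torch.Size([1, 3, 1, 4])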
@@ -442,19 +442,26 @@ def validate_exclusive_idx(rank: int, ex_idx: int):

 # "Wraps" a dim (up to one time) for the given rank, allowing
 # dims to be specified using negative indices
-def canonicalize_dim(rank: int, idx: int) -> int:
-    # TODO: add a comment for why this is
-    _rank = rank if rank != 0 else 1
+def canonicalize_dim(rank: int, idx: int, wrap_scalar: bool = True) -> int:
+    if rank < 0:
+        msg = f"Rank cannot be negative but got {rank}"
+        raise IndexError(msg)

-    if idx >= 0 and idx < _rank:
+    if rank == 0:
+        if not wrap_scalar:
+            msg = f"Dimension specified as {idx} but tensor has no dimensions"
+            raise IndexError(msg)
+        rank = 1
+
+    if idx >= 0 and idx < rank:
         return idx

     if idx < 0:
-        _idx = idx + _rank
+        _idx = idx + rank
     else:
         _idx = idx

-    if _idx < 0 or _idx > _rank:
+    if _idx < 0 or _idx >= rank:
         # Same error message as in aten/src/ATen/WrapDimUtils.h:49
         msg = "Dimension out of range (expected to be in range of [{0}, {1}], but got {2})".format(
             -rank, rank - 1, idx
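A standalone sketch of the off-by-one named in the commit title (old_accepts and new_accepts are illustrative stand-ins for the bounds checks, not the real helpers): the old bound `_idx > _rank` let idx == rank through, while the new `_idx >= rank` rejects it.

def old_accepts(rank, idx):
    _rank = rank if rank != 0 else 1
    _idx = idx + _rank if idx < 0 else idx
    return not (_idx < 0 or _idx > _rank)   # True means the dim was accepted

def new_accepts(rank, idx, wrap_scalar=True):
    if rank == 0:
        if not wrap_scalar:
            raise IndexError(f"Dimension specified as {idx} but tensor has no dimensions")
        rank = 1
    _idx = idx + rank if idx < 0 else idx
    return not (_idx < 0 or _idx >= rank)

print(old_accepts(2, 2))  # True: out-of-range dim 2 slipped through for rank 2
print(new_accepts(2, 2))  # False: now correctly rejected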
@@ -2172,8 +2172,7 @@ def cat(tensors: TensorSequenceType, dim: int = 0) -> TensorLikeType:
 @out_wrapper()
 def column_stack(tensors: TensorSequenceType) -> TensorLikeType:
     aligned_tensors = tuple(
-        x if x.ndim > 1 else prims.expand_dims(x, list(range(x.ndim, 2)))
-        for x in tensors
+        x if x.ndim > 1 else x.reshape((x.numel(), 1)) for x in tensors
     )
     return cat(aligned_tensors, 1)
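Illustration of the behavior the ref targets, using the public torch.column_stack (not the ref itself): inputs with fewer than two dimensions are treated as columns of shape (numel, 1) before concatenation along dim 1.

import torch

a = torch.tensor([1, 2, 3])
b = torch.tensor([4, 5, 6])
print(torch.column_stack((a, b)))
# tensor([[1, 4],
#         [2, 5],
#         [3, 6]])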
@@ -2684,16 +2683,17 @@ def rot90(
     a: TensorLikeType, k: int = 1, dims: DimsSequenceType = (0, 1)
 ) -> TensorLikeType:
     """Reference implementation of :func:`torch.rot90`."""
-    dims_ = utils.canonicalize_dims(a.ndim, dims)
-    # Required to silence MyPy errors
-    assert isinstance(dims_, (tuple, list))
-    dims = dims_
     if len(dims) != 2:
         raise RuntimeError(
             f"expected total rotation dims == 2, but got dims = {len(dims)}"
         )
     if a.ndim < 2:
         raise RuntimeError(f"expected total dims >= 2, but got total dims = {a.ndim}")

+    # Do this after the initial checks to be compatible with the behavior in
+    # core.
+    dims = utils.canonicalize_dims(a.ndim, dims)
+
     if dims[0] == dims[1]:
         raise RuntimeError(
             f"expected rotation dims to be different, but got dim0 = {dims[0]} and dim1 = {dims[1]}"
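Illustration of why canonicalization is deferred (a sketch against the public torch.rot90; the exact error text is what core is expected to raise): with a 1-d input and the default dims, core reports the rank check rather than an out-of-range dim, so the ref validates first and canonicalizes afterwards.

import torch

x = torch.arange(4).reshape(2, 2)
print(torch.rot90(x, 1, (0, 1)))
# tensor([[1, 3],
#         [0, 2]])

try:
    torch.rot90(torch.arange(3), 1, (0, 1))
except RuntimeError as e:
    print(e)  # expected: "expected total dims >= 2, but got total dims = 1"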
@@ -3029,8 +3029,9 @@ swap_axes = transpose
 def unsqueeze(a: TensorLikeType, dim: int) -> TensorLikeType:
     # Note that unsqueeze canonicalizes with rank + 1 because it allows
     # a new innermost dimension to be specified
-    dim = utils.canonicalize_dim(a.ndim + 1, dim)
-    return prims.expand_dims(a, (dim,))
+    ndim = a.ndim + 1
+    dim = utils.canonicalize_dim(ndim, dim)
+    return prims.expand_dims(a, (dim,), ndim=ndim)


 # NOTE: shape is a vararg because Tensor.reshape can be called with as
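Illustration of the unsqueeze behavior that motivates canonicalizing against a.ndim + 1 (plain torch calls, not the ref): dim may equal a.ndim to address a new innermost dimension, which a canonicalization against a.ndim alone would reject.

import torch

a = torch.randn(2, 3)
print(torch.unsqueeze(a, 2).shape)   # torch.Size([2, 3, 1])  (new innermost dim)
print(torch.unsqueeze(a, -1).shape)  # torch.Size([2, 3, 1])
print(torch.unsqueeze(a, 0).shape)   # torch.Size([1, 2, 3])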
@@ -2362,7 +2362,7 @@ def sample_inputs_unbind(op_info, device, dtype, requires_grad, **kwargs):
 def error_inputs_unbind(op_info, device):
     make_arg = partial(make_tensor, dtype=torch.int32, device=device, requires_grad=False)
     yield ErrorInput(SampleInput(make_arg(()), args=(0,)), error_type=IndexError,
-                     error_regex="dimension specified as 0 but tensor has no dimensions")
+                     error_regex="Dimension specified as 0 but tensor has no dimensions")
     yield ErrorInput(SampleInput(make_arg((2,)), args=(2,)), error_type=IndexError,
                      error_regex="Dimension out of range")
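Mirroring the ErrorInputs above with plain calls (illustrative only; the exact messages assume a build with this change):

import torch

print(torch.arange(4).unbind(0))  # (tensor(0), tensor(1), tensor(2), tensor(3))
try:
    torch.arange(2, dtype=torch.int32).unbind(2)   # dim 2 is out of range for a 1-d tensor
except IndexError as e:
    print(e)  # expected: "Dimension out of range (expected to be in range of [-1, 0], but got 2)"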