[primTorch] Fix off by 1 in canonicalize_dim (#83198)

Also fix an issue in the `unsqueeze` ref due to this change.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83198
Approved by: https://github.com/ngimel
Nikita Karetnikov 2022-08-16 01:15:03 +02:00 committed by PyTorch MergeBot
parent 6a5ca409da
commit 4010f96121
6 changed files with 37 additions and 20 deletions

@@ -7,10 +7,13 @@ int64_t maybe_wrap_dim_slow(
     int64_t dim,
     int64_t dim_post_expr,
     bool wrap_scalar) {
-  if (dim_post_expr <= 0) {
+  TORCH_CHECK_INDEX(
+      dim_post_expr >= 0, "Rank cannot be negative but got ", dim_post_expr);
+
+  if (dim_post_expr == 0) {
     TORCH_CHECK_INDEX(
         wrap_scalar,
-        "dimension specified as ",
+        "Dimension specified as ",
         dim,
         " but tensor has no dimensions");
     return c10::maybe_wrap_dim(dim, /*dim_post_expr=*/1, /*wrap_scalar=*/false);
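
Note: this C++ change mirrors the Python change below: a negative rank is now rejected with a clear message, and a zero-dim tensor is either wrapped as rank 1 (when wrap_scalar is set) or rejected. A minimal sketch of the observable eager behavior, consistent with the tests in this commit:

import torch

t = torch.tensor(5)   # zero-dim (scalar) tensor
t.sum(dim=0)          # OK: wrap_scalar treats the scalar as rank 1
t.sum(dim=-1)         # OK for the same reason
t.unflatten(0, [0])   # IndexError: Dimension specified as 0 but tensor has no dimensions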

@@ -5888,7 +5888,7 @@ class TestTorch(TestCase):
             torch.tensor([1]).unflatten(0, [])
         with self.assertRaisesRegex(RuntimeError, r"Provided sizes \[2, 2\] don't multiply up to the size of dim 0 \(1\)"):
             torch.tensor([1]).unflatten(0, [2, 2])
-        with self.assertRaisesRegex(IndexError, r"dimension specified as 0 but tensor has no dimensions"):
+        with self.assertRaisesRegex(IndexError, r"Dimension specified as 0 but tensor has no dimensions"):
             torch.tensor(1).unflatten(0, [0])
         with self.assertRaisesRegex(RuntimeError, r"only one dimension can be inferred"):
             torch.randn(5, 10).unflatten(1, (-1, -1))

@@ -1384,12 +1384,18 @@ conj = _make_prim(
 )
 
 
-def expand_dims(a: TensorLikeType, dimensions: DimsSequenceType) -> TensorLikeType:
+def expand_dims(
+    a: TensorLikeType, dimensions: DimsSequenceType, ndim=None
+) -> TensorLikeType:
     """
     Creates a view of a with a.ndim + len(dimensions) dimensions, with new
     dimensions of length one at the dimensions specified by dimensions.
     """
-    dims = sorted(utils.canonicalize_dims(a.ndim, dimensions))  # type: ignore[arg-type]
+    if ndim is not None:
+        # TODO: this is only here to support the unsqueeze ref
+        dims = sorted(utils.canonicalize_dims(ndim, dimensions))  # type: ignore[arg-type]
+    else:
+        dims = sorted(utils.canonicalize_dims(a.ndim, dimensions))  # type: ignore[arg-type]
     if len(set(dims)) != len(dims):
         msg = "Received duplicate dimensions to expand in {0}".format(str(dimensions))
         raise ValueError(msg)
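
Note: the new ndim argument lets a caller canonicalize against a rank larger than a.ndim; the unsqueeze ref needs this to accept a new innermost dimension (see its hunk below). A minimal sketch, assuming expand_dims is importable from torch._prims:

import torch
from torch._prims import expand_dims

a = torch.randn(2, 3)
expand_dims(a, (0,)).shape          # torch.Size([1, 2, 3])
# dim == a.ndim is only valid against the rank of the *result*:
expand_dims(a, (2,), ndim=3).shape  # torch.Size([2, 3, 1])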

@@ -442,19 +442,26 @@ def validate_exclusive_idx(rank: int, ex_idx: int):
 
 # "Wraps" a dim (up to one time) for the given rank, allowing
 # dims to be specified using negative indices
-def canonicalize_dim(rank: int, idx: int) -> int:
-    # TODO: add a comment for why this is
-    _rank = rank if rank != 0 else 1
-
-    if idx >= 0 and idx < _rank:
+def canonicalize_dim(rank: int, idx: int, wrap_scalar: bool = True) -> int:
+    if rank < 0:
+        msg = f"Rank cannot be negative but got {rank}"
+        raise IndexError(msg)
+
+    if rank == 0:
+        if not wrap_scalar:
+            msg = f"Dimension specified as {idx} but tensor has no dimensions"
+            raise IndexError(msg)
+        rank = 1
+
+    if idx >= 0 and idx < rank:
         return idx
 
     if idx < 0:
-        _idx = idx + _rank
+        _idx = idx + rank
     else:
         _idx = idx
 
-    if _idx < 0 or _idx > _rank:
+    if _idx < 0 or _idx >= rank:
         # Same error message as in aten/src/ATen/WrapDimUtils.h:49
         msg = "Dimension out of range (expected to be in range of [{0}, {1}], but got {2})".format(
             -rank, rank - 1, idx
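
Note: the off-by-one was in the upper-bound check: with _idx > _rank, idx == rank slipped through and an out-of-range index was returned; with _idx >= rank it now raises. A minimal sketch, assuming canonicalize_dim is imported from its module (torch._prims_common at the time of this PR):

from torch._prims_common import canonicalize_dim

canonicalize_dim(3, -1)   # == 2: negative dims wrap as before
canonicalize_dim(3, 3)    # IndexError now; previously returned 3, an out-of-range index
canonicalize_dim(0, 0)    # == 0: scalars wrap to rank 1 by default
canonicalize_dim(0, 0, wrap_scalar=False)  # IndexError: Dimension specified as 0 ...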

@@ -2172,8 +2172,7 @@ def cat(tensors: TensorSequenceType, dim: int = 0) -> TensorLikeType:
 
 @out_wrapper()
 def column_stack(tensors: TensorSequenceType) -> TensorLikeType:
     aligned_tensors = tuple(
-        x if x.ndim > 1 else prims.expand_dims(x, list(range(x.ndim, 2)))
-        for x in tensors
+        x if x.ndim > 1 else x.reshape((x.numel(), 1)) for x in tensors
     )
     return cat(aligned_tensors, 1)
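
Note: with the tightened bound, the old prims.expand_dims(x, list(range(x.ndim, 2))) call would now fail for a 1-D input (dim index 1 is out of range for rank 1), so the ref builds the column with a plain reshape instead. The observable behavior is unchanged:

import torch

x = torch.tensor([1, 2, 3])        # 1-D input
x.reshape((x.numel(), 1)).shape    # torch.Size([3, 1]): a single column
torch.column_stack((x, x)).shape   # torch.Size([3, 2])
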
@@ -2684,16 +2683,17 @@ def rot90(
     a: TensorLikeType, k: int = 1, dims: DimsSequenceType = (0, 1)
 ) -> TensorLikeType:
     """Reference implementation of :func:`torch.rot90`."""
-    dims_ = utils.canonicalize_dims(a.ndim, dims)
-    # Required to silence MyPy errors
-    assert isinstance(dims_, (tuple, list))
-    dims = dims_
     if len(dims) != 2:
         raise RuntimeError(
             f"expected total rotation dims == 2, but got dims = {len(dims)}"
         )
     if a.ndim < 2:
         raise RuntimeError(f"expected total dims >= 2, but got total dims = {a.ndim}")
+
+    # Do this after the initial checks to be compatible with the behavior in
+    # core.
+    dims = utils.canonicalize_dims(a.ndim, dims)
+
     if dims[0] == dims[1]:
         raise RuntimeError(
             f"expected rotation dims to be different, but got dim0 = {dims[0]} and dim1 = {dims[1]}"
@@ -3029,8 +3029,9 @@ swap_axes = transpose
 def unsqueeze(a: TensorLikeType, dim: int) -> TensorLikeType:
     # Note that unsqueeze canonicalizes with rank + 1 because it allows
     # a new innermost dimension to be specified
-    dim = utils.canonicalize_dim(a.ndim + 1, dim)
-    return prims.expand_dims(a, (dim,))
+    ndim = a.ndim + 1
+    dim = utils.canonicalize_dim(ndim, dim)
+    return prims.expand_dims(a, (dim,), ndim=ndim)
 
 
 # NOTE: shape is a vararg because Tensor.reshape can be called with as
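
Note: unsqueeze accepts dim == a.ndim because the new dimension may be innermost, so it canonicalizes against a.ndim + 1 and must pass that same rank through to expand_dims (the ndim override above). For example:

import torch

a = torch.randn(2, 3)
a.unsqueeze(2).shape   # torch.Size([2, 3, 1]): dim == a.ndim is valid
a.unsqueeze(-1).shape  # torch.Size([2, 3, 1]): -1 wraps against ndim + 1
a.unsqueeze(3)         # IndexError: Dimension out of range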

@@ -2362,7 +2362,7 @@ def sample_inputs_unbind(op_info, device, dtype, requires_grad, **kwargs):
 
 def error_inputs_unbind(op_info, device):
     make_arg = partial(make_tensor, dtype=torch.int32, device=device, requires_grad=False)
     yield ErrorInput(SampleInput(make_arg(()), args=(0,)), error_type=IndexError,
-                     error_regex="dimension specified as 0 but tensor has no dimensions")
+                     error_regex="Dimension specified as 0 but tensor has no dimensions")
     yield ErrorInput(SampleInput(make_arg((2,)), args=(2,)), error_type=IndexError,
                      error_regex="Dimension out of range")