mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[primTorch] Fix off by 1 in canonicalize_dim (#83198)
Also fix an issue in the `unsqueeze` ref due to this change. Pull Request resolved: https://github.com/pytorch/pytorch/pull/83198 Approved by: https://github.com/ngimel
This commit is contained in:
parent
6a5ca409da
commit
4010f96121
|
|
@ -7,10 +7,13 @@ int64_t maybe_wrap_dim_slow(
|
|||
int64_t dim,
|
||||
int64_t dim_post_expr,
|
||||
bool wrap_scalar) {
|
||||
if (dim_post_expr <= 0) {
|
||||
TORCH_CHECK_INDEX(
|
||||
dim_post_expr >= 0, "Rank cannot be negative but got ", dim_post_expr);
|
||||
|
||||
if (dim_post_expr == 0) {
|
||||
TORCH_CHECK_INDEX(
|
||||
wrap_scalar,
|
||||
"dimension specified as ",
|
||||
"Dimension specified as ",
|
||||
dim,
|
||||
" but tensor has no dimensions");
|
||||
return c10::maybe_wrap_dim(dim, /*dim_post_expr=*/1, /*wrap_scalar=*/false);
|
||||
|
|
|
|||
|
|
@ -5888,7 +5888,7 @@ class TestTorch(TestCase):
|
|||
torch.tensor([1]).unflatten(0, [])
|
||||
with self.assertRaisesRegex(RuntimeError, r"Provided sizes \[2, 2\] don't multiply up to the size of dim 0 \(1\)"):
|
||||
torch.tensor([1]).unflatten(0, [2, 2])
|
||||
with self.assertRaisesRegex(IndexError, r"dimension specified as 0 but tensor has no dimensions"):
|
||||
with self.assertRaisesRegex(IndexError, r"Dimension specified as 0 but tensor has no dimensions"):
|
||||
torch.tensor(1).unflatten(0, [0])
|
||||
with self.assertRaisesRegex(RuntimeError, r"only one dimension can be inferred"):
|
||||
torch.randn(5, 10).unflatten(1, (-1, -1))
|
||||
|
|
|
|||
|
|
@ -1384,12 +1384,18 @@ conj = _make_prim(
|
|||
)
|
||||
|
||||
|
||||
def expand_dims(a: TensorLikeType, dimensions: DimsSequenceType) -> TensorLikeType:
|
||||
def expand_dims(
|
||||
a: TensorLikeType, dimensions: DimsSequenceType, ndim=None
|
||||
) -> TensorLikeType:
|
||||
"""
|
||||
Creates a view of a with a.ndim + len(dimensions) dimensions, with new
|
||||
dimensions of length one at the dimensions specified by dimensions.
|
||||
"""
|
||||
dims = sorted(utils.canonicalize_dims(a.ndim, dimensions)) # type: ignore[arg-type]
|
||||
if ndim is not None:
|
||||
# TODO: this is only here to support the unsqueeze ref
|
||||
dims = sorted(utils.canonicalize_dims(ndim, dimensions)) # type: ignore[arg-type]
|
||||
else:
|
||||
dims = sorted(utils.canonicalize_dims(a.ndim, dimensions)) # type: ignore[arg-type]
|
||||
if len(set(dims)) != len(dims):
|
||||
msg = "Received duplicate dimensions to expand in {0}".format(str(dimensions))
|
||||
raise ValueError(msg)
|
||||
|
|
|
|||
|
|
@ -442,19 +442,26 @@ def validate_exclusive_idx(rank: int, ex_idx: int):
|
|||
|
||||
# "Wraps" a dim (up to one time) for the given rank, allowing
|
||||
# dims to be specified using negative indices
|
||||
def canonicalize_dim(rank: int, idx: int) -> int:
|
||||
# TODO: add a comment for why this is
|
||||
_rank = rank if rank != 0 else 1
|
||||
def canonicalize_dim(rank: int, idx: int, wrap_scalar: bool = True) -> int:
|
||||
if rank < 0:
|
||||
msg = f"Rank cannot be negative but got {rank}"
|
||||
raise IndexError(msg)
|
||||
|
||||
if idx >= 0 and idx < _rank:
|
||||
if rank == 0:
|
||||
if not wrap_scalar:
|
||||
msg = f"Dimension specified as {idx} but tensor has no dimensions"
|
||||
raise IndexError(msg)
|
||||
rank = 1
|
||||
|
||||
if idx >= 0 and idx < rank:
|
||||
return idx
|
||||
|
||||
if idx < 0:
|
||||
_idx = idx + _rank
|
||||
_idx = idx + rank
|
||||
else:
|
||||
_idx = idx
|
||||
|
||||
if _idx < 0 or _idx > _rank:
|
||||
if _idx < 0 or _idx >= rank:
|
||||
# Same error message as in aten/src/ATen/WrapDimUtils.h:49
|
||||
msg = "Dimension out of range (expected to be in range of [{0}, {1}], but got {2})".format(
|
||||
-rank, rank - 1, idx
|
||||
|
|
|
|||
|
|
@ -2172,8 +2172,7 @@ def cat(tensors: TensorSequenceType, dim: int = 0) -> TensorLikeType:
|
|||
@out_wrapper()
|
||||
def column_stack(tensors: TensorSequenceType) -> TensorLikeType:
|
||||
aligned_tensors = tuple(
|
||||
x if x.ndim > 1 else prims.expand_dims(x, list(range(x.ndim, 2)))
|
||||
for x in tensors
|
||||
x if x.ndim > 1 else x.reshape((x.numel(), 1)) for x in tensors
|
||||
)
|
||||
return cat(aligned_tensors, 1)
|
||||
|
||||
|
|
@ -2684,16 +2683,17 @@ def rot90(
|
|||
a: TensorLikeType, k: int = 1, dims: DimsSequenceType = (0, 1)
|
||||
) -> TensorLikeType:
|
||||
"""Reference implementation of :func:`torch.rot90`."""
|
||||
dims_ = utils.canonicalize_dims(a.ndim, dims)
|
||||
# Required to silence MyPy errors
|
||||
assert isinstance(dims_, (tuple, list))
|
||||
dims = dims_
|
||||
if len(dims) != 2:
|
||||
raise RuntimeError(
|
||||
f"expected total rotation dims == 2, but got dims = {len(dims)}"
|
||||
)
|
||||
if a.ndim < 2:
|
||||
raise RuntimeError(f"expected total dims >= 2, but got total dims = {a.ndim}")
|
||||
|
||||
# Do this after the initial checks to be compatible with the behavior in
|
||||
# core.
|
||||
dims = utils.canonicalize_dims(a.ndim, dims)
|
||||
|
||||
if dims[0] == dims[1]:
|
||||
raise RuntimeError(
|
||||
f"expected rotation dims to be different, but got dim0 = {dims[0]} and dim1 = {dims[1]}"
|
||||
|
|
@ -3029,8 +3029,9 @@ swap_axes = transpose
|
|||
def unsqueeze(a: TensorLikeType, dim: int) -> TensorLikeType:
|
||||
# Note that unsqueeze canonicalizes with rank + 1 because it allows
|
||||
# a new innermost dimension to be specified
|
||||
dim = utils.canonicalize_dim(a.ndim + 1, dim)
|
||||
return prims.expand_dims(a, (dim,))
|
||||
ndim = a.ndim + 1
|
||||
dim = utils.canonicalize_dim(ndim, dim)
|
||||
return prims.expand_dims(a, (dim,), ndim=ndim)
|
||||
|
||||
|
||||
# NOTE: shape is a vararg because Tensor.reshape can be called with as
|
||||
|
|
|
|||
|
|
@ -2362,7 +2362,7 @@ def sample_inputs_unbind(op_info, device, dtype, requires_grad, **kwargs):
|
|||
def error_inputs_unbind(op_info, device):
|
||||
make_arg = partial(make_tensor, dtype=torch.int32, device=device, requires_grad=False)
|
||||
yield ErrorInput(SampleInput(make_arg(()), args=(0,)), error_type=IndexError,
|
||||
error_regex="dimension specified as 0 but tensor has no dimensions")
|
||||
error_regex="Dimension specified as 0 but tensor has no dimensions")
|
||||
yield ErrorInput(SampleInput(make_arg((2,)), args=(2,)), error_type=IndexError,
|
||||
error_regex="Dimension out of range")
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user