[dynamic shapes] unbacked-safe should_swap (#160473)

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160473
Approved by: https://github.com/laithsakka
Pian Pawakapan 2025-09-11 18:51:22 +00:00 committed by PyTorch MergeBot
parent 9cac1b9259
commit ac72f81c12
7 changed files with 114 additions and 38 deletions
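For context on the "unbacked-safe" in the title: the stride/size comparisons inside `should_swap` previously went through `guard_size_oblivious`, which can guard or raise a data-dependent error when the operands are unbacked SymInts, whereas `guard_or_false` simply answers `False` when the comparison cannot be decided. A minimal sketch of that fallback contract, assuming only that the helper passes plain Python bools through unchanged (the diff itself calls it on comparisons that may be plain ints):

```python
from torch.fx.experimental.symbolic_shapes import guard_or_false

# On concrete values the comparison is just evaluated...
assert guard_or_false(3 >= 2) is True
assert guard_or_false(2 >= 3) is False

# ...the interesting case is a SymBool built from unbacked SymInts
# (e.g. strides derived from .item()/.tolist() under torch.compile):
# there guard_or_false returns False instead of raising a
# data-dependent error, which is what makes should_swap unbacked-safe.
```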

View File

@@ -9604,6 +9604,69 @@ def ___make_guard_fn():
        f(torch.randn(9, requires_grad=True), torch.tensor([3, 6]))

    @torch._dynamo.config.patch(capture_scalar_outputs=True)
    def test_dim_order(self):
        @torch.compile(dynamic=False, fullgraph=True, backend="eager")
        def f(x):
            x = x.permute(3, 0, 2, 1)
            return x, x.dim_order()

        @torch.compile(dynamic=False, fullgraph=True, backend="eager")
        def g(x):
            return x.dim_order()

        @torch.compile(dynamic=False, fullgraph=True, backend="eager")
        def h0(xs, ambiguity_check=False):
            u0, u1, u2 = xs.tolist()
            torch._check(u2 >= u0)
            torch._check(u1 >= u0)
            # stride ordering still isn't unique here, should raise
            y = torch.empty_strided([4, 4, 4], [u0, u1, u2])
            return y.dim_order(ambiguity_check=ambiguity_check)

        @torch.compile(dynamic=False, fullgraph=True, backend="eager")
        def h1(xs, ambiguity_check=False):
            u0, u1, u2 = xs.tolist()
            y = torch.empty_strided([4, 4, 4], [u0, u0, u0])  # no ordering
            return y.dim_order(ambiguity_check=ambiguity_check)

        # check that for functions permuting contiguous input, the original stride is recovered with dim_order.
        def test(x):
            stride_inp = tuple(x.stride())
            f_out, f_order = f(x)
            self.assertEqual(stride_inp, tuple(f_out.stride(i) for i in f_order))

        # shape: [4, u0, 5, u1]
        x0 = torch.randn(4, 1, 5, 2)
        torch._dynamo.decorators.mark_unbacked(x0, 1)
        torch._dynamo.decorators.mark_unbacked(x0, 3)
        test(x0)

        # shape: [u0, u1, u2, u3]
        x1 = torch.randn(4, 1, 5, 2)
        for i in range(x1.ndim):
            torch._dynamo.decorators.mark_unbacked(x1, i)
        test(x1)

        # custom strides (all integers)
        x2 = torch.randn(10000)
        x2 = x2.as_strided([4, 4, 4, 4], [1, 2, 4, 8])
        assert g(x2) == (3, 2, 1, 0)

        # custom unbacked strides with no ordering: ambiguity check should raise
        xs = torch.tensor([2, 3, 4])
        h0(xs)
        with self.assertRaisesRegex(
            torch._dynamo.exc.TorchRuntimeError,
            r"The tensor does not have unique dim order.",
        ):
            h0(xs, ambiguity_check=True)
        with self.assertRaisesRegex(
            torch._dynamo.exc.TorchRuntimeError,
            r"The tensor does not have unique dim order.",
        ):
            h1(xs, ambiguity_check=True)

    def test_str_format_assert1(self):
        @torch.compile(backend="eager", fullgraph=True)
        def fn(img):
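For readers skimming the test, here is a minimal eager-mode sketch (no compilation, concrete shapes) of the property the `test(x)` helper asserts. The printed values follow from the documented `dim_order` semantics (dimensions listed from outermost to innermost in memory, i.e. by decreasing stride) and are assumptions of this sketch, not output taken from the PR's CI:

```python
import torch

x = torch.randn(2, 3, 4, 5)      # contiguous, strides (60, 20, 5, 1)
y = x.permute(3, 0, 2, 1)        # same permutation as f() above; strides (1, 60, 5, 20)

# dim_order() lists dimensions from outermost to innermost in memory.
order = y.dim_order()
print(order)                                  # (1, 3, 2, 0)

# Reading the permuted strides back in that order recovers the original
# contiguous strides -- exactly what test() checks via assertEqual.
print(tuple(y.stride(i) for i in order))      # (60, 20, 5, 1) == x.stride()
```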

View File

@@ -579,7 +579,7 @@ def view_copy_dtype(
def _get_shape_permutation_like(
self: torch.Tensor,
) -> tuple[utils.ShapeType, utils.StrideType]:
physical_layout = utils.compute_elementwise_output_logical_to_physical_perm(self)
physical_layout, _ = utils.compute_elementwise_output_logical_to_physical_perm(self)
shape = [self.shape[l] for l in physical_layout]
permutation = [0] * len(shape)

View File

@@ -3465,7 +3465,7 @@ def meta_index_Tensor(self, indices):
# Note that perm here is the reverse of the 'perm_' decided by
# TensorIteratorBase::reorder_dimensions
restrided_self = _restride_src(self)
perm = utils.compute_elementwise_output_logical_to_physical_perm(restrided_self)
perm, _ = utils.compute_elementwise_output_logical_to_physical_perm(restrided_self)
# Follow TensorIteratorBase::allocate_or_resize_outputs
if list(perm) != list(range(len(perm))):

View File

@@ -404,7 +404,7 @@ def _prim_elementwise_meta(
utils.check_same_device(*args_, allow_cpu_scalar_tensors=True)
utils.check_same_shape(*args_, allow_cpu_scalar_tensors=True)
l2p_perm = utils.compute_elementwise_output_logical_to_physical_perm(*args_)
l2p_perm, _ = utils.compute_elementwise_output_logical_to_physical_perm(*args_)
shape = utils.extract_shape(*args_, allow_cpu_scalar_tensors=True)
# Acquires the dtype

View File

@@ -534,12 +534,9 @@ def is_non_overlapping_and_dense(a: Tensor) -> bool:
# This is also INCORRECT because it does not model TensorIterator's
# short-circuit, which can cause different strides.
def compute_elementwise_output_logical_to_physical_perm(
*tensors, _skip_checks=False
) -> list[int]:
from torch.fx.experimental.symbolic_shapes import (
guard_or_false,
guard_size_oblivious,
)
*tensors, _skip_checks=False, ambiguity_check=False
) -> tuple[list[int], bool]:
from torch.fx.experimental.symbolic_shapes import guard_or_false
if not _skip_checks and len(tensors) == 0:
msg = "Can't compute elementwise output strides for zero tensors!"
@ -558,15 +555,15 @@ def compute_elementwise_output_logical_to_physical_perm(
# Short-circuits for CPU scalar case
if len(tensors) == 0:
return []
return [], False
# Short-circuits for shapes with zero or one dimensions
# TODO: are these necessary?
ndim = tensors[0].ndim
if ndim == 0:
return []
return [], False
if ndim == 1:
return [0]
return [0], False
# Short-circuits if contiguous or channels last, following the fake fast path.
# This reduces the number of guards we end up making
@@ -584,42 +581,40 @@
)
if is_contiguous and not is_channels_last:
return list(range(ndim))
return list(range(ndim)), False
if is_channels_last and not is_contiguous:
return [0, *list(range(2, ndim)), 1]
return [0, *list(range(2, ndim)), 1], False
shape = tensors[0].shape
def should_swap(idx_a, idx_b):
def ge(a, b):
"""
Returns true if a is symbolically greater than or equal to b, assuming a >= 0, b >= 0.
"""
if guard_or_false(b == 0):
return True
elif guard_or_false(a == 0):
return False
return guard_or_false(a >= b) or guard_or_false(a % b == 0)
for tensor in tensors:
stride_a = tensor.stride()[idx_a]
stride_b = tensor.stride()[idx_b]
if guard_size_oblivious(stride_a == 0) or guard_size_oblivious(
stride_b == 0
):
if guard_or_false(stride_a == 0) or guard_or_false(stride_b == 0):
continue
if guard_or_false(stride_a == stride_b):
if guard_size_oblivious(shape[idx_a] > shape[idx_b]):
return 1
# when stride_a = 1, we want stride_a < stride_b to be TRUE
# when stride_b = 1, we want stride_a < stride_b to be FALSE
elif guard_or_false(stride_a == 1):
return -1
elif guard_or_false(stride_b == 1):
if ge(shape[idx_b], shape[idx_a]):
continue
return 1
if guard_size_oblivious(stride_a < stride_b):
if ge(stride_b, stride_a):
return -1
if guard_size_oblivious(stride_a > stride_b):
return 1
# stride_a == stride_b
if guard_size_oblivious(shape[idx_a] > shape[idx_b]):
if ge(stride_a, stride_b):
return 1
# Note: this case is hit if all strides are zero,
@ -644,7 +639,16 @@ def compute_elementwise_output_logical_to_physical_perm(
elif comparison < 0:
break
return list(reversed(perm))
# verify we've imposed ordering if ambiguity_check=True
raise_ambiguous = False
if ambiguity_check:
for i, j in zip(range(ndim - 1), range(1, ndim)):
order = should_swap(perm[i], perm[j])
if order != -1:
raise_ambiguous = True
break
return list(reversed(perm)), raise_ambiguous
def compute_elementwise_output_strides(*tensors) -> tuple[int, ...]:
@@ -674,7 +678,7 @@ def compute_elementwise_output_strides(*tensors) -> tuple[int, ...]:
if ndim == 1:
return (1,)
logical_to_physical_perm = compute_elementwise_output_logical_to_physical_perm(
logical_to_physical_perm, _ = compute_elementwise_output_logical_to_physical_perm(
*tensors, _skip_checks=True
)
permuted_shape = apply_perm(shape, logical_to_physical_perm) # to physical
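As a deliberately simplified, concrete-integer sketch of the ordering the function above imposes (my own illustration, not code from the diff): dims are listed from largest to smallest stride, and the new `ambiguity_check` flag reports when adjacent dims cannot be strictly ordered. The real implementation must route every comparison through `guard_or_false` / `ge` because the strides may be unbacked SymInts, and it also handles multiple tensors, broadcasting, zero strides, and size-1 dims, none of which appear here:

```python
def toy_logical_to_physical_perm(strides: list[int]) -> tuple[list[int], bool]:
    # Order dims from outermost (largest stride) to innermost (smallest).
    perm = sorted(range(len(strides)), key=lambda d: strides[d], reverse=True)
    # If two adjacent dims in that order have equal strides, the ordering is
    # not unique -- this is what raise_ambiguous reports above.
    ambiguous = any(strides[a] == strides[b] for a, b in zip(perm, perm[1:]))
    return perm, ambiguous

print(toy_logical_to_physical_perm([1, 2, 4, 8]))  # ([3, 2, 1, 0], False), matches g(x2) in the test
print(toy_logical_to_physical_perm([2, 2, 2]))     # ([0, 1, 2], True): no unique order
```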

View File

@@ -5111,7 +5111,7 @@ def empty_like(
)
# memory_format == torch.preserve_format
logical_to_physical_perm = (
logical_to_physical_perm, _ = (
utils.compute_elementwise_output_logical_to_physical_perm(a)
)
# identity perm is [2, 1, 0]

View File

@@ -1585,17 +1585,19 @@ class Tensor(torch._C.TensorBase):
If any two dimensions have the same stride, swapping these dimensions won't
change how data is accessed, leading to multiple correct dimension orders.
"""
from torch.fx.experimental.symbolic_shapes import guard_or_false
sizes = tensor.size()
strides = tensor.stride()
# Check if there are any duplicate strides
has_duplicate_strides = any(
earlier == later for earlier, later in zip(strides, strides[1:])
guard_or_false(earlier == later)
for earlier, later in zip(strides, strides[1:])
)
# Check if there are any singleton dimensions
has_singleton_dims = any(size == 1 for size in sizes)
has_singleton_dims = any(guard_or_false(size == 1) for size in sizes)
return has_duplicate_strides or has_singleton_dims
@@ -1615,7 +1617,14 @@ class Tensor(torch._C.TensorBase):
import torch._prims_common as utils
return tuple(utils.compute_elementwise_output_logical_to_physical_perm(self))
out_perm, raise_ambiguity = (
utils.compute_elementwise_output_logical_to_physical_perm(
self, ambiguity_check=ambiguity_check
)
)
if raise_ambiguity:
raise RuntimeError("The tensor does not have unique dim order.")
return tuple(out_perm)
def _update_names(self, names, inplace):
if has_torch_function_unary(self):
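To make the duplicate-strides docstring in this hunk concrete, a small eager-mode sketch of the two situations it describes: duplicate strides, and (per the singleton check a few lines above) size-1 dims. Both are expected to raise under `ambiguity_check=True` and to silently pick one valid order otherwise; the exact behavior is assumed from the code in this diff rather than re-verified here.

```python
import torch

# Duplicate strides: dims 1 and 2 can be swapped without changing how
# memory is read, so more than one dim order is valid.
a = torch.empty_strided((2, 3, 3), (9, 1, 1))

# A size-1 dim can sit anywhere in the order without changing access.
b = torch.randn(2, 1, 3)

for t in (a, b):
    print(t.dim_order())               # some valid order is returned
    try:
        t.dim_order(ambiguity_check=True)
    except RuntimeError as err:
        print(err)                     # The tensor does not have unique dim order.
```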