[dynamic shapes] unbacked-safe should_swap (#160473)

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160473
Approved by: https://github.com/laithsakka
Pian Pawakapan 2025-09-11 18:51:22 +00:00 committed by PyTorch MergeBot
parent 9cac1b9259
commit ac72f81c12
7 changed files with 114 additions and 38 deletions
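What the change enables, as a minimal sketch (not part of the diff; it mirrors the new h0 case in the test added below, with f and xs as illustrative names): under torch.compile, dim_order() on a tensor whose strides are unbacked symints no longer forces data-dependent guards, and dim_order(ambiguity_check=True) raises when the ordering cannot be proven unique.

import torch
import torch._dynamo  # for config.patch

@torch.compile(dynamic=False, fullgraph=True, backend="eager")
def f(xs, ambiguity_check=False):
    # u0, u1, u2 are unbacked symints coming from tensor data
    u0, u1, u2 = xs.tolist()
    torch._check(u2 >= u0)
    torch._check(u1 >= u0)
    # stride ordering between u1 and u2 is still unknown
    y = torch.empty_strided([4, 4, 4], [u0, u1, u2])
    return y.dim_order(ambiguity_check=ambiguity_check)

xs = torch.tensor([2, 3, 4])
with torch._dynamo.config.patch(capture_scalar_outputs=True):
    f(xs)                          # ok: ambiguity is tolerated by default
    # f(xs, ambiguity_check=True)  # raises "The tensor does not have unique dim order."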


@@ -9604,6 +9604,69 @@ def ___make_guard_fn():
         f(torch.randn(9, requires_grad=True), torch.tensor([3, 6]))
 
+    @torch._dynamo.config.patch(capture_scalar_outputs=True)
+    def test_dim_order(self):
+        @torch.compile(dynamic=False, fullgraph=True, backend="eager")
+        def f(x):
+            x = x.permute(3, 0, 2, 1)
+            return x, x.dim_order()
+
+        @torch.compile(dynamic=False, fullgraph=True, backend="eager")
+        def g(x):
+            return x.dim_order()
+
+        @torch.compile(dynamic=False, fullgraph=True, backend="eager")
+        def h0(xs, ambiguity_check=False):
+            u0, u1, u2 = xs.tolist()
+            torch._check(u2 >= u0)
+            torch._check(u1 >= u0)
+            # stride ordering still isn't unique here, should raise
+            y = torch.empty_strided([4, 4, 4], [u0, u1, u2])
+            return y.dim_order(ambiguity_check=ambiguity_check)
+
+        @torch.compile(dynamic=False, fullgraph=True, backend="eager")
+        def h1(xs, ambiguity_check=False):
+            u0, u1, u2 = xs.tolist()
+            y = torch.empty_strided([4, 4, 4], [u0, u0, u0])  # no ordering
+            return y.dim_order(ambiguity_check=ambiguity_check)
+
+        # check that for functions permuting contiguous input, the original stride is recovered with dim_order.
+        def test(x):
+            stride_inp = tuple(x.stride())
+            f_out, f_order = f(x)
+            self.assertEqual(stride_inp, tuple(f_out.stride(i) for i in f_order))
+
+        # shape: [4, u0, 5, u1]
+        x0 = torch.randn(4, 1, 5, 2)
+        torch._dynamo.decorators.mark_unbacked(x0, 1)
+        torch._dynamo.decorators.mark_unbacked(x0, 3)
+        test(x0)
+
+        # shape: [u0, u1, u2, u3]
+        x1 = torch.randn(4, 1, 5, 2)
+        for i in range(x1.ndim):
+            torch._dynamo.decorators.mark_unbacked(x1, i)
+        test(x1)
+
+        # custom strides (all integers)
+        x2 = torch.randn(10000)
+        x2 = x2.as_strided([4, 4, 4, 4], [1, 2, 4, 8])
+        assert g(x2) == (3, 2, 1, 0)
+
+        # custom unbacked strides with no ordering: ambiguity check should raise
+        xs = torch.tensor([2, 3, 4])
+        h0(xs)
+        with self.assertRaisesRegex(
+            torch._dynamo.exc.TorchRuntimeError,
+            r"The tensor does not have unique dim order.",
+        ):
+            h0(xs, ambiguity_check=True)
+        with self.assertRaisesRegex(
+            torch._dynamo.exc.TorchRuntimeError,
+            r"The tensor does not have unique dim order.",
+        ):
+            h1(xs, ambiguity_check=True)
+
     def test_str_format_assert1(self):
         @torch.compile(backend="eager", fullgraph=True)
         def fn(img):


@@ -579,7 +579,7 @@ def view_copy_dtype(
 def _get_shape_permutation_like(
     self: torch.Tensor,
 ) -> tuple[utils.ShapeType, utils.StrideType]:
-    physical_layout = utils.compute_elementwise_output_logical_to_physical_perm(self)
+    physical_layout, _ = utils.compute_elementwise_output_logical_to_physical_perm(self)
     shape = [self.shape[l] for l in physical_layout]
     permutation = [0] * len(shape)


@@ -3465,7 +3465,7 @@ def meta_index_Tensor(self, indices):
     # Note that perm here is the reverse of the 'perm_' decided by
     # TensorIteratorBase::reorder_dimensions
     restrided_self = _restride_src(self)
-    perm = utils.compute_elementwise_output_logical_to_physical_perm(restrided_self)
+    perm, _ = utils.compute_elementwise_output_logical_to_physical_perm(restrided_self)
     # Follow TensorIteratorBase::allocate_or_resize_outputs
     if list(perm) != list(range(len(perm))):


@@ -404,7 +404,7 @@ def _prim_elementwise_meta(
     utils.check_same_device(*args_, allow_cpu_scalar_tensors=True)
     utils.check_same_shape(*args_, allow_cpu_scalar_tensors=True)
 
-    l2p_perm = utils.compute_elementwise_output_logical_to_physical_perm(*args_)
+    l2p_perm, _ = utils.compute_elementwise_output_logical_to_physical_perm(*args_)
     shape = utils.extract_shape(*args_, allow_cpu_scalar_tensors=True)
 
     # Acquires the dtype


@@ -534,12 +534,9 @@ def is_non_overlapping_and_dense(a: Tensor) -> bool:
 # This is also INCORRECT because it does not model TensorIterator's
 # short-circuit, which can cause different strides.
 def compute_elementwise_output_logical_to_physical_perm(
-    *tensors, _skip_checks=False
-) -> list[int]:
-    from torch.fx.experimental.symbolic_shapes import (
-        guard_or_false,
-        guard_size_oblivious,
-    )
+    *tensors, _skip_checks=False, ambiguity_check=False
+) -> tuple[list[int], bool]:
+    from torch.fx.experimental.symbolic_shapes import guard_or_false
 
     if not _skip_checks and len(tensors) == 0:
         msg = "Can't compute elementwise output strides for zero tensors!"
@@ -558,15 +555,15 @@ def compute_elementwise_output_logical_to_physical_perm(
     # Short-circuits for CPU scalar case
     if len(tensors) == 0:
-        return []
+        return [], False
 
     # Short-circuits for shapes with zero or one dimensions
     # TODO: are these necessary?
     ndim = tensors[0].ndim
     if ndim == 0:
-        return []
+        return [], False
     if ndim == 1:
-        return [0]
+        return [0], False
 
     # Short-circuits if contiguous or channels last, following the fake fast path.
     # This reduces the number of guards we end up making
@@ -584,42 +581,40 @@ def compute_elementwise_output_logical_to_physical_perm(
     )
 
     if is_contiguous and not is_channels_last:
-        return list(range(ndim))
+        return list(range(ndim)), False
 
     if is_channels_last and not is_contiguous:
-        return [0, *list(range(2, ndim)), 1]
+        return [0, *list(range(2, ndim)), 1], False
 
     shape = tensors[0].shape
 
     def should_swap(idx_a, idx_b):
+        def ge(a, b):
+            """
+            Returns true if a is symbolically greater than or equal to b, assuming a >= 0, b >= 0.
+            """
+            if guard_or_false(b == 0):
+                return True
+            elif guard_or_false(a == 0):
+                return False
+            return guard_or_false(a >= b) or guard_or_false(a % b == 0)
+
         for tensor in tensors:
             stride_a = tensor.stride()[idx_a]
             stride_b = tensor.stride()[idx_b]
 
-            if guard_size_oblivious(stride_a == 0) or guard_size_oblivious(
-                stride_b == 0
-            ):
+            if guard_or_false(stride_a == 0) or guard_or_false(stride_b == 0):
                 continue
 
             if guard_or_false(stride_a == stride_b):
-                if guard_size_oblivious(shape[idx_a] > shape[idx_b]):
-                    return 1
-
-            # when stride_a = 1, we want stride_a < stride_b to be TRUE
-            # when stride_b = 1, we want stride_a < stride_b to be FALSE
-            elif guard_or_false(stride_a == 1):
-                return -1
-            elif guard_or_false(stride_b == 1):
-                return 1
-
-            if guard_size_oblivious(stride_a < stride_b):
-                return -1
-
-            if guard_size_oblivious(stride_a > stride_b):
-                return 1
-
-            # stride_a == stride_b
-            if guard_size_oblivious(shape[idx_a] > shape[idx_b]):
+                if ge(shape[idx_b], shape[idx_a]):
+                    continue
                 return 1
+
+            if ge(stride_b, stride_a):
+                return -1
+
+            if ge(stride_a, stride_b):
+                return 1
 
             # Note: this case is hit if all strides are zero,
@@ -644,7 +639,16 @@ def compute_elementwise_output_logical_to_physical_perm(
             elif comparison < 0:
                 break
 
-    return list(reversed(perm))
+    # verify we've imposed ordering if ambiguity_check=True
+    raise_ambiguous = False
+    if ambiguity_check:
+        for i, j in zip(range(ndim - 1), range(1, ndim)):
+            order = should_swap(perm[i], perm[j])
+            if order != -1:
+                raise_ambiguous = True
+                break
+
+    return list(reversed(perm)), raise_ambiguous
 
 
 def compute_elementwise_output_strides(*tensors) -> tuple[int, ...]:
@@ -674,7 +678,7 @@ def compute_elementwise_output_strides(*tensors) -> tuple[int, ...]:
     if ndim == 1:
         return (1,)
 
-    logical_to_physical_perm = compute_elementwise_output_logical_to_physical_perm(
+    logical_to_physical_perm, _ = compute_elementwise_output_logical_to_physical_perm(
         *tensors, _skip_checks=True
     )
     permuted_shape = apply_perm(shape, logical_to_physical_perm)  # to physical


@@ -5111,7 +5111,7 @@ def empty_like(
     )
 
     # memory_format == torch.preserve_format
-    logical_to_physical_perm = (
+    logical_to_physical_perm, _ = (
         utils.compute_elementwise_output_logical_to_physical_perm(a)
     )
     # identity perm is [2, 1, 0]


@@ -1585,17 +1585,19 @@ class Tensor(torch._C.TensorBase):
             If any two dimensions have the same stride, swapping these dimensions won't
             change how data is accessed, leading to multiple correct dimension orders.
             """
+            from torch.fx.experimental.symbolic_shapes import guard_or_false
+
             sizes = tensor.size()
             strides = tensor.stride()
 
             # Check if there are any duplicate strides
             has_duplicate_strides = any(
-                earlier == later for earlier, later in zip(strides, strides[1:])
+                guard_or_false(earlier == later)
+                for earlier, later in zip(strides, strides[1:])
             )
 
             # Check if there are any singleton dimensions
-            has_singleton_dims = any(size == 1 for size in sizes)
+            has_singleton_dims = any(guard_or_false(size == 1) for size in sizes)
 
             return has_duplicate_strides or has_singleton_dims
@@ -1615,7 +1617,14 @@ class Tensor(torch._C.TensorBase):
         import torch._prims_common as utils
 
-        return tuple(utils.compute_elementwise_output_logical_to_physical_perm(self))
+        out_perm, raise_ambiguity = (
+            utils.compute_elementwise_output_logical_to_physical_perm(
+                self, ambiguity_check=ambiguity_check
+            )
+        )
+        if raise_ambiguity:
+            raise RuntimeError("The tensor does not have unique dim order.")
+        return tuple(out_perm)
 
     def _update_names(self, names, inplace):
         if has_torch_function_unary(self):