[dynamic shapes] unbacked-safe should_swap (#160473)
Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160473
Approved by: https://github.com/laithsakka
This commit is contained in: parent 9cac1b9259, commit ac72f81c12
@@ -9604,6 +9604,69 @@ def ___make_guard_fn():

        f(torch.randn(9, requires_grad=True), torch.tensor([3, 6]))

    @torch._dynamo.config.patch(capture_scalar_outputs=True)
    def test_dim_order(self):
        @torch.compile(dynamic=False, fullgraph=True, backend="eager")
        def f(x):
            x = x.permute(3, 0, 2, 1)
            return x, x.dim_order()

        @torch.compile(dynamic=False, fullgraph=True, backend="eager")
        def g(x):
            return x.dim_order()

        @torch.compile(dynamic=False, fullgraph=True, backend="eager")
        def h0(xs, ambiguity_check=False):
            u0, u1, u2 = xs.tolist()
            torch._check(u2 >= u0)
            torch._check(u1 >= u0)
            # stride ordering still isn't unique here, should raise
            y = torch.empty_strided([4, 4, 4], [u0, u1, u2])
            return y.dim_order(ambiguity_check=ambiguity_check)

        @torch.compile(dynamic=False, fullgraph=True, backend="eager")
        def h1(xs, ambiguity_check=False):
            u0, u1, u2 = xs.tolist()
            y = torch.empty_strided([4, 4, 4], [u0, u0, u0])  # no ordering
            return y.dim_order(ambiguity_check=ambiguity_check)

        # check that for functions permuting contiguous input, the original stride is recovered with dim_order.
        def test(x):
            stride_inp = tuple(x.stride())
            f_out, f_order = f(x)
            self.assertEqual(stride_inp, tuple(f_out.stride(i) for i in f_order))

        # shape: [4, u0, 5, u1]
        x0 = torch.randn(4, 1, 5, 2)
        torch._dynamo.decorators.mark_unbacked(x0, 1)
        torch._dynamo.decorators.mark_unbacked(x0, 3)
        test(x0)

        # shape: [u0, u1, u2, u3]
        x1 = torch.randn(4, 1, 5, 2)
        for i in range(x1.ndim):
            torch._dynamo.decorators.mark_unbacked(x1, i)
        test(x1)

        # custom strides (all integers)
        x2 = torch.randn(10000)
        x2 = x2.as_strided([4, 4, 4, 4], [1, 2, 4, 8])
        assert g(x2) == (3, 2, 1, 0)

        # custom unbacked strides with no ordering: ambiguity check should raise
        xs = torch.tensor([2, 3, 4])
        h0(xs)
        with self.assertRaisesRegex(
            torch._dynamo.exc.TorchRuntimeError,
            r"The tensor does not have unique dim order.",
        ):
            h0(xs, ambiguity_check=True)
        with self.assertRaisesRegex(
            torch._dynamo.exc.TorchRuntimeError,
            r"The tensor does not have unique dim order.",
        ):
            h1(xs, ambiguity_check=True)

    def test_str_format_assert1(self):
        @torch.compile(backend="eager", fullgraph=True)
        def fn(img):
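For readers skimming the new test, here is a minimal eager-mode sketch (outside torch.compile, no unbacked symbols involved) of the invariant the test(x) helper asserts: indexing a permuted tensor's strides by its dim_order() recovers the strides of the original contiguous input.

import torch

# Minimal eager sketch of the property checked by test(x) above: for a
# permuted contiguous tensor, walking the output strides in dim_order()
# order recovers the original (contiguous) strides.
x = torch.randn(2, 3, 4, 5)            # contiguous input
y = x.permute(3, 0, 2, 1)              # same data, permuted view
order = y.dim_order()                  # dims listed from outermost to innermost
assert tuple(y.stride(i) for i in order) == x.stride()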
@@ -579,7 +579,7 @@ def view_copy_dtype(
def _get_shape_permutation_like(
    self: torch.Tensor,
) -> tuple[utils.ShapeType, utils.StrideType]:
    physical_layout = utils.compute_elementwise_output_logical_to_physical_perm(self)
    physical_layout, _ = utils.compute_elementwise_output_logical_to_physical_perm(self)
    shape = [self.shape[l] for l in physical_layout]

    permutation = [0] * len(shape)
@@ -3465,7 +3465,7 @@ def meta_index_Tensor(self, indices):
    # Note that perm here is the reverse of the 'perm_' decided by
    # TensorIteratorBase::reorder_dimensions
    restrided_self = _restride_src(self)
    perm = utils.compute_elementwise_output_logical_to_physical_perm(restrided_self)
    perm, _ = utils.compute_elementwise_output_logical_to_physical_perm(restrided_self)

    # Follow TensorIteratorBase::allocate_or_resize_outputs
    if list(perm) != list(range(len(perm))):
@@ -404,7 +404,7 @@ def _prim_elementwise_meta(
    utils.check_same_device(*args_, allow_cpu_scalar_tensors=True)
    utils.check_same_shape(*args_, allow_cpu_scalar_tensors=True)

    l2p_perm = utils.compute_elementwise_output_logical_to_physical_perm(*args_)
    l2p_perm, _ = utils.compute_elementwise_output_logical_to_physical_perm(*args_)
    shape = utils.extract_shape(*args_, allow_cpu_scalar_tensors=True)

    # Acquires the dtype
@@ -534,12 +534,9 @@ def is_non_overlapping_and_dense(a: Tensor) -> bool:
# This is also INCORRECT because it does not model TensorIterator's
# short-circuit, which can cause different strides.
def compute_elementwise_output_logical_to_physical_perm(
    *tensors, _skip_checks=False
) -> list[int]:
    from torch.fx.experimental.symbolic_shapes import (
        guard_or_false,
        guard_size_oblivious,
    )
    *tensors, _skip_checks=False, ambiguity_check=False
) -> tuple[list[int], bool]:
    from torch.fx.experimental.symbolic_shapes import guard_or_false

    if not _skip_checks and len(tensors) == 0:
        msg = "Can't compute elementwise output strides for zero tensors!"
@@ -558,15 +555,15 @@ def compute_elementwise_output_logical_to_physical_perm(

    # Short-circuits for CPU scalar case
    if len(tensors) == 0:
        return []
        return [], False

    # Short-circuits for shapes with zero or one dimensions
    # TODO: are these necessary?
    ndim = tensors[0].ndim
    if ndim == 0:
        return []
        return [], False
    if ndim == 1:
        return [0]
        return [0], False

    # Short-circuits if contiguous or channels last, following the fake fast path.
    # This reduces the number of guards we end up making
@@ -584,42 +581,40 @@ def compute_elementwise_output_logical_to_physical_perm(
    )

    if is_contiguous and not is_channels_last:
        return list(range(ndim))
        return list(range(ndim)), False

    if is_channels_last and not is_contiguous:
        return [0, *list(range(2, ndim)), 1]
        return [0, *list(range(2, ndim)), 1], False

    shape = tensors[0].shape

    def should_swap(idx_a, idx_b):
        def ge(a, b):
            """
            Returns true if a is symbolically greater than or equal to b, assuming a >= 0, b >= 0.
            """
            if guard_or_false(b == 0):
                return True
            elif guard_or_false(a == 0):
                return False
            return guard_or_false(a >= b) or guard_or_false(a % b == 0)

        for tensor in tensors:
            stride_a = tensor.stride()[idx_a]
            stride_b = tensor.stride()[idx_b]
            if guard_size_oblivious(stride_a == 0) or guard_size_oblivious(
                stride_b == 0
            ):
            if guard_or_false(stride_a == 0) or guard_or_false(stride_b == 0):
                continue

            if guard_or_false(stride_a == stride_b):
                if guard_size_oblivious(shape[idx_a] > shape[idx_b]):
                    return 1

            # when stride_a = 1, we want stride_a < stride_b to be TRUE
            # when stride_b = 1, we want stride_a < stride_b to be FALSE
            elif guard_or_false(stride_a == 1):
                return -1

            elif guard_or_false(stride_b == 1):
                if ge(shape[idx_b], shape[idx_a]):
                    continue
                return 1

            if guard_size_oblivious(stride_a < stride_b):
            if ge(stride_b, stride_a):
                return -1

            if guard_size_oblivious(stride_a > stride_b):
                return 1

            # stride_a == stride_b
            if guard_size_oblivious(shape[idx_a] > shape[idx_b]):
            if ge(stride_a, stride_b):
                return 1

            # Note: this case is hit if all strides are zero,
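A concrete, non-symbolic sketch of the ordering rule that ge encodes above (plain ints stand in for backed and unbacked SymInts; this is illustrative, not the PR's code): a stride is treated as "at least as large" either when that is provable outright or when exact divisibility holds, which is presumably how nested contiguous strides relate even when the sizes themselves are unknown.

# Plain-int stand-in for ge(a, b) above; with concrete values every
# comparison is decidable, so guard_or_false degenerates to the comparison.
def ge_concrete(a: int, b: int) -> bool:
    if b == 0:
        return True
    if a == 0:
        return False
    return a >= b or a % b == 0

assert ge_concrete(12, 4)        # provably larger: stays the outer dim
assert ge_concrete(4, 4)         # equal strides also satisfy ge
assert not ge_concrete(3, 4)     # smaller and not a multiple: stays inner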
@@ -644,7 +639,16 @@ def compute_elementwise_output_logical_to_physical_perm(
            elif comparison < 0:
                break

    return list(reversed(perm))
    # verify we've imposed ordering if ambiguity_check=True
    raise_ambiguous = False
    if ambiguity_check:
        for i, j in zip(range(ndim - 1), range(1, ndim)):
            order = should_swap(perm[i], perm[j])
            if order != -1:
                raise_ambiguous = True
                break

    return list(reversed(perm)), raise_ambiguous


def compute_elementwise_output_strides(*tensors) -> tuple[int, ...]:
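A hedged sketch of the new calling convention (this helper is private, so the snippet is illustrative only): pre-existing call sites simply unpack and discard the second return value, while dim_order() threads ambiguity_check through and raises when the flag comes back set, as in the Tensor hunks further below.

import torch
import torch._prims_common as utils

# The helper now returns (perm, ambiguous) instead of just perm; callers that
# only need the permutation drop the flag, exactly as the updated call sites
# above do.
t = torch.randn(2, 3, 4).permute(2, 0, 1)
perm, ambiguous = utils.compute_elementwise_output_logical_to_physical_perm(t)
print(perm, ambiguous)   # [1, 2, 0] False -- the flag stays False unless ambiguity_check=True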
@@ -674,7 +678,7 @@ def compute_elementwise_output_strides(*tensors) -> tuple[int, ...]:
    if ndim == 1:
        return (1,)

    logical_to_physical_perm = compute_elementwise_output_logical_to_physical_perm(
    logical_to_physical_perm, _ = compute_elementwise_output_logical_to_physical_perm(
        *tensors, _skip_checks=True
    )
    permuted_shape = apply_perm(shape, logical_to_physical_perm)  # to physical
@@ -5111,7 +5111,7 @@ def empty_like(
    )

    # memory_format == torch.preserve_format
    logical_to_physical_perm = (
    logical_to_physical_perm, _ = (
        utils.compute_elementwise_output_logical_to_physical_perm(a)
    )
    # identity perm is [2, 1, 0]
@@ -1585,17 +1585,19 @@ class Tensor(torch._C.TensorBase):
            If any two dimensions have the same stride, swapping these dimensions won't
            change how data is accessed, leading to multiple correct dimension orders.
            """
            from torch.fx.experimental.symbolic_shapes import guard_or_false

            sizes = tensor.size()
            strides = tensor.stride()

            # Check if there are any duplicate strides
            has_duplicate_strides = any(
                earlier == later for earlier, later in zip(strides, strides[1:])
                guard_or_false(earlier == later)
                for earlier, later in zip(strides, strides[1:])
            )

            # Check if there are any singleton dimensions
            has_singleton_dims = any(size == 1 for size in sizes)
            has_singleton_dims = any(guard_or_false(size == 1) for size in sizes)

            return has_duplicate_strides or has_singleton_dims
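Eager illustration (concrete sizes, so guard_or_false reduces to a plain comparison) of the two conditions the helper above tests: duplicate adjacent strides and size-1 dimensions both leave the dim order non-unique.

import torch

# A size-1 dimension produces both a duplicate stride and a singleton size,
# so either check alone is enough to flag the order as ambiguous here.
t = torch.randn(3, 1, 5)
strides, sizes = t.stride(), t.size()
has_duplicate_strides = any(e == l for e, l in zip(strides, strides[1:]))
has_singleton_dims = any(s == 1 for s in sizes)
print(strides, has_duplicate_strides, has_singleton_dims)  # (5, 5, 1) True True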
@@ -1615,7 +1617,14 @@ class Tensor(torch._C.TensorBase):

        import torch._prims_common as utils

        return tuple(utils.compute_elementwise_output_logical_to_physical_perm(self))
        out_perm, raise_ambiguity = (
            utils.compute_elementwise_output_logical_to_physical_perm(
                self, ambiguity_check=ambiguity_check
            )
        )
        if raise_ambiguity:
            raise RuntimeError("The tensor does not have unique dim order.")
        return tuple(out_perm)

    def _update_names(self, names, inplace):
        if has_torch_function_unary(self):
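And the eager-mode counterpart of the compiled test expectations, assuming the same error text as the regex in the test above: on a layout where every dimension has the same stride, dim_order() still returns some order by default, while ambiguity_check=True refuses to pick one.

import torch

t = torch.empty_strided([4, 4, 4], [2, 2, 2])   # every dimension has the same stride
print(t.dim_order())                            # a default order is still returned
try:
    t.dim_order(ambiguity_check=True)
except RuntimeError as err:
    print(err)   # message includes "The tensor does not have unique dim order."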