[dynamic shapes] unbacked-safe should_swap (#160473)
Fixes #ISSUE_NUMBER
Pull Request resolved: https://github.com/pytorch/pytorch/pull/160473
Approved by: https://github.com/laithsakka
parent 9cac1b9259 · commit ac72f81c12
@@ -9604,6 +9604,69 @@ def ___make_guard_fn():
         f(torch.randn(9, requires_grad=True), torch.tensor([3, 6]))
 
+    @torch._dynamo.config.patch(capture_scalar_outputs=True)
+    def test_dim_order(self):
+        @torch.compile(dynamic=False, fullgraph=True, backend="eager")
+        def f(x):
+            x = x.permute(3, 0, 2, 1)
+            return x, x.dim_order()
+
+        @torch.compile(dynamic=False, fullgraph=True, backend="eager")
+        def g(x):
+            return x.dim_order()
+
+        @torch.compile(dynamic=False, fullgraph=True, backend="eager")
+        def h0(xs, ambiguity_check=False):
+            u0, u1, u2 = xs.tolist()
+            torch._check(u2 >= u0)
+            torch._check(u1 >= u0)
+            # stride ordering still isn't unique here, should raise
+            y = torch.empty_strided([4, 4, 4], [u0, u1, u2])
+            return y.dim_order(ambiguity_check=ambiguity_check)
+
+        @torch.compile(dynamic=False, fullgraph=True, backend="eager")
+        def h1(xs, ambiguity_check=False):
+            u0, u1, u2 = xs.tolist()
+            y = torch.empty_strided([4, 4, 4], [u0, u0, u0])  # no ordering
+            return y.dim_order(ambiguity_check=ambiguity_check)
+
+        # check that for functions permuting contiguous input, the original stride is recovered with dim_order.
+        def test(x):
+            stride_inp = tuple(x.stride())
+            f_out, f_order = f(x)
+            self.assertEqual(stride_inp, tuple(f_out.stride(i) for i in f_order))
+
+        # shape: [4, u0, 5, u1]
+        x0 = torch.randn(4, 1, 5, 2)
+        torch._dynamo.decorators.mark_unbacked(x0, 1)
+        torch._dynamo.decorators.mark_unbacked(x0, 3)
+        test(x0)
+
+        # shape: [u0, u1, u2, u3]
+        x1 = torch.randn(4, 1, 5, 2)
+        for i in range(x1.ndim):
+            torch._dynamo.decorators.mark_unbacked(x1, i)
+        test(x1)
+
+        # custom strides (all integers)
+        x2 = torch.randn(10000)
+        x2 = x2.as_strided([4, 4, 4, 4], [1, 2, 4, 8])
+        assert g(x2) == (3, 2, 1, 0)
+
+        # custom unbacked strides with no ordering: ambiguity check should raise
+        xs = torch.tensor([2, 3, 4])
+        h0(xs)
+        with self.assertRaisesRegex(
+            torch._dynamo.exc.TorchRuntimeError,
+            r"The tensor does not have unique dim order.",
+        ):
+            h0(xs, ambiguity_check=True)
+        with self.assertRaisesRegex(
+            torch._dynamo.exc.TorchRuntimeError,
+            r"The tensor does not have unique dim order.",
+        ):
+            h1(xs, ambiguity_check=True)
+
     def test_str_format_assert1(self):
         @torch.compile(backend="eager", fullgraph=True)
         def fn(img):
@@ -579,7 +579,7 @@ def view_copy_dtype(
 def _get_shape_permutation_like(
     self: torch.Tensor,
 ) -> tuple[utils.ShapeType, utils.StrideType]:
-    physical_layout = utils.compute_elementwise_output_logical_to_physical_perm(self)
+    physical_layout, _ = utils.compute_elementwise_output_logical_to_physical_perm(self)
     shape = [self.shape[l] for l in physical_layout]
 
     permutation = [0] * len(shape)
@@ -3465,7 +3465,7 @@ def meta_index_Tensor(self, indices):
     # Note that perm here is the reverse of the 'perm_' decided by
     # TensorIteratorBase::reorder_dimensions
     restrided_self = _restride_src(self)
-    perm = utils.compute_elementwise_output_logical_to_physical_perm(restrided_self)
+    perm, _ = utils.compute_elementwise_output_logical_to_physical_perm(restrided_self)
 
     # Follow TensorIteratorBase::allocate_or_resize_outputs
     if list(perm) != list(range(len(perm))):
@@ -404,7 +404,7 @@ def _prim_elementwise_meta(
     utils.check_same_device(*args_, allow_cpu_scalar_tensors=True)
     utils.check_same_shape(*args_, allow_cpu_scalar_tensors=True)
 
-    l2p_perm = utils.compute_elementwise_output_logical_to_physical_perm(*args_)
+    l2p_perm, _ = utils.compute_elementwise_output_logical_to_physical_perm(*args_)
     shape = utils.extract_shape(*args_, allow_cpu_scalar_tensors=True)
 
     # Acquires the dtype
@@ -534,12 +534,9 @@ def is_non_overlapping_and_dense(a: Tensor) -> bool:
 # This is also INCORRECT because it does not model TensorIterator's
 # short-circuit, which can cause different strides.
 def compute_elementwise_output_logical_to_physical_perm(
-    *tensors, _skip_checks=False
-) -> list[int]:
-    from torch.fx.experimental.symbolic_shapes import (
-        guard_or_false,
-        guard_size_oblivious,
-    )
+    *tensors, _skip_checks=False, ambiguity_check=False
+) -> tuple[list[int], bool]:
+    from torch.fx.experimental.symbolic_shapes import guard_or_false
 
     if not _skip_checks and len(tensors) == 0:
         msg = "Can't compute elementwise output strides for zero tensors!"
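(Not part of the diff.) With the signature change above, the helper now returns a pair instead of a bare permutation, and the later call-site hunks unpack it as `perm, _`. A minimal eager-mode sketch of the new call convention, assuming a build that includes this change; the example tensor is illustrative only:

```python
import torch
import torch._prims_common as utils

# A permuted (non-contiguous) tensor whose strides are strictly ordered.
a = torch.randn(2, 3, 4).permute(2, 0, 1)

# The helper now returns (logical-to-physical permutation, ambiguity flag); the
# flag is only computed when ambiguity_check=True is requested.
perm, is_ambiguous = utils.compute_elementwise_output_logical_to_physical_perm(
    a, ambiguity_check=True
)
print(perm, is_ambiguous)  # e.g. [1, 2, 0] False: this layout has a unique order
```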
@@ -558,15 +555,15 @@ def compute_elementwise_output_logical_to_physical_perm(
 
     # Short-circuits for CPU scalar case
     if len(tensors) == 0:
-        return []
+        return [], False
 
     # Short-circuits for shapes with zero or one dimensions
     # TODO: are these necessary?
     ndim = tensors[0].ndim
     if ndim == 0:
-        return []
+        return [], False
     if ndim == 1:
-        return [0]
+        return [0], False
 
     # Short-circuits if contiguous or channels last, following the fake fast path.
     # This reduces the number of guards we end up making
@@ -584,42 +581,40 @@ def compute_elementwise_output_logical_to_physical_perm(
     )
 
     if is_contiguous and not is_channels_last:
-        return list(range(ndim))
+        return list(range(ndim)), False
 
     if is_channels_last and not is_contiguous:
-        return [0, *list(range(2, ndim)), 1]
+        return [0, *list(range(2, ndim)), 1], False
 
     shape = tensors[0].shape
 
     def should_swap(idx_a, idx_b):
+        def ge(a, b):
+            """
+            Returns true if a is symbolically greater than or equal to b, assuming a >= 0, b >= 0.
+            """
+            if guard_or_false(b == 0):
+                return True
+            elif guard_or_false(a == 0):
+                return False
+            return guard_or_false(a >= b) or guard_or_false(a % b == 0)
+
         for tensor in tensors:
             stride_a = tensor.stride()[idx_a]
             stride_b = tensor.stride()[idx_b]
-            if guard_size_oblivious(stride_a == 0) or guard_size_oblivious(
-                stride_b == 0
-            ):
+            if guard_or_false(stride_a == 0) or guard_or_false(stride_b == 0):
                 continue
 
             if guard_or_false(stride_a == stride_b):
-                if guard_size_oblivious(shape[idx_a] > shape[idx_b]):
-                    return 1
+                if ge(shape[idx_b], shape[idx_a]):
+                    continue
+                return 1
 
-            # when stride_a = 1, we want stride_a < stride_b to be TRUE
-            # when stride_b = 1, we want stride_a < stride_b to be FALSE
-            elif guard_or_false(stride_a == 1):
+            if ge(stride_b, stride_a):
                 return -1
-            elif guard_or_false(stride_b == 1):
-                return 1
 
-            if guard_size_oblivious(stride_a < stride_b):
-                return -1
-
-            if guard_size_oblivious(stride_a > stride_b):
-                return 1
-
-            # stride_a == stride_b
-            if guard_size_oblivious(shape[idx_a] > shape[idx_b]):
+            if ge(stride_a, stride_b):
                 return 1
 
             # Note: this case is hit if all strides are zero,
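(Not part of the diff.) A small self-contained model of what the new `ge` helper buys: `guard_or_false` is stood in for by a predicate that only answers True when the comparison is statically decidable, so `should_swap` acts on a stride ordering only when it can be proven and simply declines to swap when an unbacked value makes it unknowable. The `None` encoding of unbacked ints below is purely illustrative, not the real symbolic machinery:

```python
import operator

def known(x):
    # Toy encoding: a plain int is a backed (known) value; None stands in for an
    # unbacked SymInt whose value is unknown at compile time.
    return x is not None

def guard_or_false(cond):
    # Toy model of guard_or_false: take the branch only when the condition is
    # definitely True; an undecidable condition (None) falls back to False instead
    # of installing a guard on an unbacked symbol.
    return cond is True

def cmp(op, a, b):
    # Comparison on toy symbolic ints: None means "depends on an unbacked value".
    return op(a, b) if known(a) and known(b) else None

def ge(a, b):
    # Mirrors the new ge() in should_swap: provable a >= b, assuming a >= 0, b >= 0.
    if guard_or_false(cmp(operator.eq, b, 0)):
        return True
    elif guard_or_false(cmp(operator.eq, a, 0)):
        return False
    return guard_or_false(cmp(operator.ge, a, b)) or guard_or_false(
        cmp(lambda x, y: x % y == 0, a, b)
    )

print(ge(8, 2))     # True  -> ordering is proven, a swap decision is safe
print(ge(2, 8))     # False -> provably not >=, the opposite ordering will fire
print(ge(None, 2))  # False -> unbacked value: nothing is proven, no swap is forced
print(ge(4, 0))     # True  -> b == 0 is always considered <= a
```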
@@ -644,7 +639,16 @@ def compute_elementwise_output_logical_to_physical_perm(
             elif comparison < 0:
                 break
 
-    return list(reversed(perm))
+    # verify we've imposed ordering if ambiguity_check=True
+    raise_ambiguous = False
+    if ambiguity_check:
+        for i, j in zip(range(ndim - 1), range(1, ndim)):
+            order = should_swap(perm[i], perm[j])
+            if order != -1:
+                raise_ambiguous = True
+                break
+
+    return list(reversed(perm)), raise_ambiguous
 
 
 def compute_elementwise_output_strides(*tensors) -> tuple[int, ...]:
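(Not part of the diff.) The ambiguity pass above re-compares adjacent entries of the already-sorted permutation and flags the result unless every pair is provably ordered (i.e. `should_swap` returns -1). A toy model of that check, with a plain comparator standing in for `should_swap`; the stride values and semantics below are illustrative only:

```python
def verify_unique_order(perm, should_swap):
    # Mirror of the new loop: every adjacent pair of the sorted permutation must be
    # provably in order, otherwise the permutation is reported as ambiguous.
    for i, j in zip(range(len(perm) - 1), range(1, len(perm))):
        if should_swap(perm[i], perm[j]) != -1:
            return False
    return True

def comparator_for(strides):
    # Stand-in for should_swap on a single tensor: -1 means "provably in order",
    # 1 means "provably out of order", 0 means "no unique ordering can be proven".
    def cmp(a, b):
        if strides[a] < strides[b]:
            return -1
        if strides[a] > strides[b]:
            return 1
        return 0
    return cmp

print(verify_unique_order([0, 1, 2], comparator_for([1, 2, 4])))  # True: strictly ordered
print(verify_unique_order([0, 1, 2], comparator_for([2, 2, 4])))  # False: duplicate stride, ambiguous
```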
@@ -674,7 +678,7 @@ def compute_elementwise_output_strides(*tensors) -> tuple[int, ...]:
     if ndim == 1:
         return (1,)
 
-    logical_to_physical_perm = compute_elementwise_output_logical_to_physical_perm(
+    logical_to_physical_perm, _ = compute_elementwise_output_logical_to_physical_perm(
         *tensors, _skip_checks=True
     )
     permuted_shape = apply_perm(shape, logical_to_physical_perm)  # to physical
@@ -5111,7 +5111,7 @@ def empty_like(
     )
 
     # memory_format == torch.preserve_format
-    logical_to_physical_perm = (
+    logical_to_physical_perm, _ = (
         utils.compute_elementwise_output_logical_to_physical_perm(a)
     )
     # identity perm is [2, 1, 0]
@@ -1585,17 +1585,19 @@ class Tensor(torch._C.TensorBase):
             If any two dimensions have the same stride, swapping these dimensions won't
             change how data is accessed, leading to multiple correct dimension orders.
             """
+            from torch.fx.experimental.symbolic_shapes import guard_or_false
+
             sizes = tensor.size()
             strides = tensor.stride()
 
             # Check if there are any duplicate strides
             has_duplicate_strides = any(
-                earlier == later for earlier, later in zip(strides, strides[1:])
+                guard_or_false(earlier == later)
+                for earlier, later in zip(strides, strides[1:])
            )
 
             # Check if there are any singleton dimensions
-            has_singleton_dims = any(size == 1 for size in sizes)
+            has_singleton_dims = any(guard_or_false(size == 1) for size in sizes)
 
             return has_duplicate_strides or has_singleton_dims
@@ -1615,7 +1617,14 @@ class Tensor(torch._C.TensorBase):
 
         import torch._prims_common as utils
 
-        return tuple(utils.compute_elementwise_output_logical_to_physical_perm(self))
+        out_perm, raise_ambiguity = (
+            utils.compute_elementwise_output_logical_to_physical_perm(
+                self, ambiguity_check=ambiguity_check
+            )
+        )
+        if raise_ambiguity:
+            raise RuntimeError("The tensor does not have unique dim order.")
+        return tuple(out_perm)
 
     def _update_names(self, names, inplace):
         if has_torch_function_unary(self):
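(Not part of the diff.) A usage sketch distilled from the new test above, showing the user-visible behavior on a build that includes this change: with unbacked strides the permutation is computed without data-dependent guard errors, and passing ambiguity_check=True surfaces the case where no unique dim order can be proven. Function and variable names here are illustrative only:

```python
import torch
import torch._dynamo

@torch.compile(dynamic=False, fullgraph=True, backend="eager")
def ambiguous_dim_order(xs, ambiguity_check=False):
    u0, u1, u2 = xs.tolist()          # unbacked ints coming from tensor data
    torch._check(u2 >= u0)
    torch._check(u1 >= u0)
    # the constraints above still do not pin down a unique stride ordering
    y = torch.empty_strided([4, 4, 4], [u0, u1, u2])
    return y.dim_order(ambiguity_check=ambiguity_check)

with torch._dynamo.config.patch(capture_scalar_outputs=True):
    xs = torch.tensor([2, 3, 4])
    print(ambiguous_dim_order(xs))    # compiles and returns some valid order
    try:
        ambiguous_dim_order(xs, ambiguity_check=True)
    except torch._dynamo.exc.TorchRuntimeError as exc:
        print("unique dim order" in str(exc))  # True: the ambiguity is reported
```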