[dynamic shapes] use try-catch instead of guard_or_true for reshape_view_helper (#152638)

Test Plan: test_export

Differential Revision: D74033649

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152638
Approved by: https://github.com/laithsakka
This commit is contained in:
Pian Pawakapan 2025-05-06 00:54:24 +00:00 committed by PyTorch MergeBot
parent d197228d43
commit 13dcf80a53
2 changed files with 33 additions and 8 deletions

View File

@ -343,6 +343,23 @@ class TestDynamismExpression(TestCase):
seq_len = torch.tensor(5)
torch.export.export(MySlice(), args=(x, seq_len))
@torch.fx.experimental._config.patch(backed_size_oblivious=True)
def test_reshape_view_backed_size_oblivious(self):
N = 3
class MyModel(torch.nn.Module):
def forward(self, x):
y = x[:-1, :] # [s0 - 1, 32]
stacked = torch.stack([y] * N, dim=0) # [N * (s0 - 1), 32]
reshaped = stacked.reshape(-1, N, 32) # [(s0 - 1), N, 32]
return reshaped
inps = (torch.randn(10, 32),)
spec = {
"x": (Dim.AUTO, Dim.STATIC),
}
ep = export(MyModel(), inps, dynamic_shapes=spec)
def test_export_constraints_error(self):
class ConflictingConstraints(torch.nn.Module):
def forward(self, x):
@ -4421,10 +4438,10 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
"The following call raised this error(.*\n)+"
f".*{re.escape('return r.view(items[0], items[2])')}(.*\n)+"
"To fix the error, insert one of the following checks before this call.*:\n"
f".*{re.escape('torch._check((items[1] % items[2]) != 0)')}.*\n"
f".*{re.escape('torch._check((items[1] % items[2]) == 0)')}(.*\n)+"
f".*{re.escape('torch._check((items[1] % items[2]) == 0)')}.*\n"
f".*{re.escape('torch._check((items[1] % items[2]) != 0)')}(.*\n)+"
f".*{re.escape('(These suggested fixes were derived by replacing `u1` with items[1]')}"
f".*{re.escape('or r.shape[1], `u2` with items[2] in Ne(Mod(u1, u2), 0) and its negation.')}",
f".*{re.escape('or r.shape[1], `u2` with items[2] in Eq(Mod(u1, u2), 0) and its negation.')}",
):
export(N(), (t,), strict=strict)

View File

@ -3731,7 +3731,11 @@ def repeat(a: Tensor, *repeat_shape) -> Tensor:
def _reshape_view_helper(a: TensorLikeType, *shape, allow_copy: bool) -> TensorLikeType:
from torch._dynamo.exc import UserError, UserErrorType
from torch.fx.experimental.symbolic_shapes import guard_or_false, guard_or_true
from torch.fx.experimental.symbolic_shapes import (
guard_or_false,
guard_or_true,
GuardOnDataDependentSymNode,
)
# Creates a valid shape
shape = utils.extract_shape_from_varargs(shape, validate=False)
@ -3834,12 +3838,16 @@ def _reshape_view_helper(a: TensorLikeType, *shape, allow_copy: bool) -> TensorL
accum = a_.shape[idx]
end = idx
while guard_or_true(accum % length != 0):
deferred.append(lambda: bool(accum % length != 0))
while True:
try:
if accum % length == 0:
break
except GuardOnDataDependentSymNode:
deferred.append(lambda: bool(accum % length == 0))
if end == a_.ndim - 1:
maybe_throw_dde()
end = end + 1
accum = accum * a_.shape[end]
end += 1
accum *= a_.shape[end]
if end != idx:
# NOTE: in this case multiple dimensions must be flatten to create the desired dimension
# This flattening is why reshape sometimes creates a copy -- because flattening