[dynamic shapes] use try-catch instead of guard_or_true for reshape_view_helper (#152638)
Test Plan: test_export

Differential Revision: D74033649
Pull Request resolved: https://github.com/pytorch/pytorch/pull/152638
Approved by: https://github.com/laithsakka
This commit is contained in:
parent d197228d43
commit 13dcf80a53
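For orientation before the diff: both idioms decide a shape condition that may involve unbacked (data-dependent) SymInts. The sketch below is illustrative only, with simplified stand-ins for the real symbolic_shapes machinery: Undecidable plays the role of GuardOnDataDependentSymNode, cond is any zero-argument callable evaluating a symbolic condition, and deferred mimics the helper's list of runtime re-checks.

class Undecidable(RuntimeError):
    """Stand-in for GuardOnDataDependentSymNode: raised when a condition
    depends on data-dependent symbols and cannot be decided at trace time."""

def guard_or_true(cond):
    # Old idiom: evaluate cond(); if it is undecidable, swallow the error
    # and optimistically answer True. Callers cannot tell a genuine True
    # from an optimistic one.
    try:
        return bool(cond())
    except Undecidable:
        return True

deferred = []  # runtime re-checks recorded for optimistic trace-time choices

def eval_or_defer(cond):
    # New idiom: evaluate cond() directly; only when evaluation raises do we
    # record a deferred runtime check and treat the condition as unresolved
    # (False), mirroring the reshape loop in this commit.
    try:
        return bool(cond())
    except Undecidable:
        deferred.append(lambda: bool(cond()))
        return False

The try-catch form lets the caller distinguish "decidably false" from "undecidable", which is exactly the distinction _reshape_view_helper needs when deciding whether to keep flattening dimensions.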
--- a/test/export/test_export.py
+++ b/test/export/test_export.py
@@ -343,6 +343,23 @@ class TestDynamismExpression(TestCase):
         seq_len = torch.tensor(5)
         torch.export.export(MySlice(), args=(x, seq_len))
 
+    @torch.fx.experimental._config.patch(backed_size_oblivious=True)
+    def test_reshape_view_backed_size_oblivious(self):
+        N = 3
+
+        class MyModel(torch.nn.Module):
+            def forward(self, x):
+                y = x[:-1, :]  # [s0 - 1, 32]
+                stacked = torch.stack([y] * N, dim=0)  # [N, s0 - 1, 32]
+                reshaped = stacked.reshape(-1, N, 32)  # [(s0 - 1), N, 32]
+                return reshaped
+
+        inps = (torch.randn(10, 32),)
+        spec = {
+            "x": (Dim.AUTO, Dim.STATIC),
+        }
+        ep = export(MyModel(), inps, dynamic_shapes=spec)
+
     def test_export_constraints_error(self):
         class ConflictingConstraints(torch.nn.Module):
             def forward(self, x):
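As a quick sanity check on the shape annotations in the new test, the same computation in plain eager mode (no export, concrete s0 = 10) gives:

import torch

x = torch.randn(10, 32)
y = x[:-1, :]                          # [9, 32], i.e. s0 - 1 with s0 = 10
stacked = torch.stack([y] * 3, dim=0)  # [3, 9, 32]
reshaped = stacked.reshape(-1, 3, 32)  # [9, 3, 32]
assert reshaped.shape == (9, 3, 32)

The backed_size_oblivious=True patch makes export reason about the backed size s0 as if it were data-dependent, which is presumably what routes this reshape's divisibility check through the new try-catch path in _reshape_view_helper.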
@@ -4421,10 +4438,10 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
             "The following call raised this error(.*\n)+"
             f".*{re.escape('return r.view(items[0], items[2])')}(.*\n)+"
             "To fix the error, insert one of the following checks before this call.*:\n"
-            f".*{re.escape('torch._check((items[1] % items[2]) != 0)')}.*\n"
-            f".*{re.escape('torch._check((items[1] % items[2]) == 0)')}(.*\n)+"
+            f".*{re.escape('torch._check((items[1] % items[2]) == 0)')}.*\n"
+            f".*{re.escape('torch._check((items[1] % items[2]) != 0)')}(.*\n)+"
             f".*{re.escape('(These suggested fixes were derived by replacing `u1` with items[1]')}"
-            f".*{re.escape('or r.shape[1], `u2` with items[2] in Ne(Mod(u1, u2), 0) and its negation.')}",
+            f".*{re.escape('or r.shape[1], `u2` with items[2] in Eq(Mod(u1, u2), 0) and its negation.')}",
         ):
             export(N(), (t,), strict=strict)
 
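The suggested fixes asserted by this regex point users at torch._check. A minimal, hypothetical illustration of how such a fix resolves a data-dependent view (the function and shapes are invented for the example, not taken from the test):

import torch

def f(x):
    n = int(x.sum())           # data-dependent: unbacked at trace time
    r = torch.zeros(2, n)
    torch._check(n % 2 == 0)   # the suggested fix: assert divisibility
    return r.view(2, 2, n // 2)

out = f(torch.ones(4))  # n == 4: the check passes and the view succeeds

Inserting the torch._check teaches the tracer a divisibility fact it could not prove on its own, so the view no longer raises a data-dependent guard error; the check is re-validated against concrete sizes at runtime.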
--- a/torch/_refs/__init__.py
+++ b/torch/_refs/__init__.py
@@ -3731,7 +3731,11 @@ def repeat(a: Tensor, *repeat_shape) -> Tensor:
 
 def _reshape_view_helper(a: TensorLikeType, *shape, allow_copy: bool) -> TensorLikeType:
     from torch._dynamo.exc import UserError, UserErrorType
-    from torch.fx.experimental.symbolic_shapes import guard_or_false, guard_or_true
+    from torch.fx.experimental.symbolic_shapes import (
+        guard_or_false,
+        guard_or_true,
+        GuardOnDataDependentSymNode,
+    )
 
     # Creates a valid shape
     shape = utils.extract_shape_from_varargs(shape, validate=False)
@@ -3834,12 +3838,16 @@ def _reshape_view_helper(a: TensorLikeType, *shape, allow_copy: bool) -> TensorLikeType:
 
             accum = a_.shape[idx]
             end = idx
-            while guard_or_true(accum % length != 0):
-                deferred.append(lambda: bool(accum % length != 0))
+            while True:
+                try:
+                    if accum % length == 0:
+                        break
+                except GuardOnDataDependentSymNode:
+                    deferred.append(lambda: bool(accum % length == 0))
                 if end == a_.ndim - 1:
                     maybe_throw_dde()
-                end = end + 1
-                accum = accum * a_.shape[end]
+                end += 1
+                accum *= a_.shape[end]
             if end != idx:
                 # NOTE: in this case multiple dimensions must be flatten to create the desired dimension
                 # This flattening is why reshape sometimes creates a copy -- because flattening
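To see what the loop computes, here is a plain-Python rendition with concrete integer sizes (no symbolic shapes, so the data-dependent branch never fires; maybe_throw_dde is replaced by a plain error):

def flatten_until_divisible(sizes, idx, length):
    # Mirror of the loop: grow accum dimension by dimension, starting at
    # idx, until the accumulated size is divisible by the target length.
    accum = sizes[idx]
    end = idx
    while accum % length != 0:
        if end == len(sizes) - 1:
            raise ValueError(f"cannot carve a dimension of size {length}")
        end += 1
        accum *= sizes[end]
    return end, accum

assert flatten_until_divisible([3, 9, 32], 0, 27) == (1, 27)

With symbolic shapes, evaluating accum % length == 0 can raise GuardOnDataDependentSymNode; the new code then records the divisibility condition in deferred (a list maintained earlier in _reshape_view_helper, outside this hunk) so it can be re-checked once concrete sizes are known, instead of guessing through guard_or_true.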