Revert "[dynamic shapes] guard_or_false for _reshape_view_helper, utils._infer_size for wildcard dims (#150127)"

This reverts commit 1dd2033c0a.

Reverted https://github.com/pytorch/pytorch/pull/150127 on behalf of https://github.com/clee2000 due to maybe caused export test to fail? export/test_draft_export.py::TestDraftExport::test_masked_linear [GH job link](https://github.com/pytorch/pytorch/actions/runs/14538768138/job/40794985504) [HUD commit link](1dd2033c0a), bad TD ([comment](https://github.com/pytorch/pytorch/pull/150127#issuecomment-2816232086))
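For context: the failure class behind both the original PR and this revert is a view/reshape whose target sizes are unbacked (data-dependent) integers produced by .item(). Because view() accepts -1 as a wildcard, tracing must decide questions like Eq(u0, -1) and u0 >= 0 about values it cannot inspect. A minimal sketch of the pattern, mirroring the restored test_check_is_size_error below (illustrative only — this is not the failing test_masked_linear case):

    import torch
    from torch.export import export

    class M(torch.nn.Module):
        def forward(self, x):
            a = x.item()  # unbacked symbol; nothing is known about its value
            # view() accepts -1, so tracing must decide Eq(a, -1), a >= 0,
            # and divisibility for a value it cannot inspect
            return torch.randn(24).view(a, 4)

    try:
        export(M(), (torch.tensor(6),))
    except Exception as e:
        # Data-dependent guard error, with suggested torch._check(...) fixes
        print(type(e).__name__)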
PyTorch MergeBot committed 2025-04-18 21:38:47 +00:00
parent bd77c3e054
commit 97d97aef24
3 changed files with 77 additions and 79 deletions

test/export/test_export.py

@@ -4306,7 +4306,7 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
     def forward(self, t):
         items = [t[i].item() for i in range(t.numel())]
         r = torch.randn([items[0], items[1]])
-        # Could not guard on data-dependent expression Ne(Mod(u1, u2), 0)
+        # Could not guard on data-dependent expression Eq(u2, -1)
         return r.view(items[0], items[2])

 M = M_v0
@@ -4315,10 +4315,9 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
     "The following call raised this error(.*\n)+"
     f".*{re.escape('return r.view(items[0], items[2])')}(.*\n)+"
     "To fix the error, insert one of the following checks before this call.*:\n"
-    f".*{re.escape('torch._check((items[1] % items[2]) != 0)')}.*\n"
-    f".*{re.escape('torch._check((items[1] % items[2]) == 0)')}(.*\n)+"
-    f".*{re.escape('(These suggested fixes were derived by replacing `u1` with items[1]')}"
-    f".*{re.escape('or r.shape[1], `u2` with items[2] in Ne(Mod(u1, u2), 0) and its negation.')}",
+    f".*{re.escape('torch._check(items[2] == (-1))')}.*\n"
+    f".*{re.escape('torch._check(items[2] != (-1))')}(.*\n)+"
+    f".*{re.escape('(These suggested fixes were derived by replacing `u2` with items[2] in Eq(u2, -1) and its negation.)')}",
 ):
     export(N(), (t,), strict=strict)
@@ -4326,12 +4325,59 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
     def forward(self, t):
         items = [t[i].item() for i in range(t.numel())]
         r = torch.randn([items[0], items[1]])
-        # TODO(pianpwk): this isn't the suggested fixes.
-        # fix issue with % being interpreted as PythonMod instead of Mod
-        torch._check(items[1] == items[2])
+        # Could not guard on data-dependent expression Eq(u2, -1)
+        torch._check(items[2] != -1)
+        # Could not guard on data-dependent expression u2 >= 0
         return r.view(items[0], items[2])

 M = M_v1
+with self.assertRaisesRegex(
+    error_type,
+    "The following call raised this error(.*\n)+"
+    f".*{re.escape('return r.view(items[0], items[2])')}(.*\n)+"
+    "To fix the error, insert one of the following checks before this call.*:\n"
+    f".*{re.escape('You can add either: torch._check_is_size(u2) or torch._check(u2>=0) Note: torch._check_is_size(u2) could prevent data dependent errors that happen in a guard_size_oblivious(..) context by opting into guard_size_oblivious reasoning. See documentation on guard_size_oblivious for more details: https://pytorch.org/docs/stable/generated/torch.fx.experimental.symbolic_shapes.guard_size_oblivious.html')}.*\n"
+    f".*{re.escape('torch._check(items[2] < 0)')}(.*\n)+"
+    f".*{re.escape('(These suggested fixes were derived by replacing `u2` with items[2] in u2 >= 0 and its negation.)')}",
+):
+    export(N(), (t,), strict=strict)
+
+class M_v2(torch.nn.Module):
+    def forward(self, t):
+        items = [t[i].item() for i in range(t.numel())]
+        r = torch.randn([items[0], items[1]])
+        # Could not guard on data-dependent expression Eq(u2, -1)
+        torch._check(items[2] != -1)
+        # Could not guard on data-dependent expression u2 >= 0
+        torch._check(items[2] >= 0)
+        # Could not guard on data-dependent expression Eq(u1, u2)
+        return r.view(items[0], items[2])
+
+M = M_v2
+with self.assertRaisesRegex(
+    error_type,
+    "The following call raised this error(.*\n)+"
+    f".*{re.escape('return r.view(items[0], items[2])')}(.*\n)+"
+    "To fix the error, insert one of the following checks before this call.*:\n"
+    f".*{re.escape('torch._check(items[2] == items[1])')}.*\n"
+    f".*{re.escape('torch._check(items[2] != items[1])')}(.*\n)+"
+    f".*{re.escape('(These suggested fixes were derived by replacing `u1` with items[1] or r.shape[1], `u2` with items[2] in Eq(u2, u1) and its negation.)')}",
+):
+    export(N(), (t,), strict=strict)
+
+class M_v3(torch.nn.Module):
+    def forward(self, t):
+        items = [t[i].item() for i in range(t.numel())]
+        r = torch.randn([items[0], items[1]])
+        # Could not guard on data-dependent expression Eq(u2, -1)
+        torch._check(items[2] != -1)
+        # Could not guard on data-dependent expression u2 >= 0
+        torch._check(items[2] >= 0)
+        # Could not guard on data-dependent expression Eq(u1, u2)
+        torch._check(items[2] == r.shape[1])
+        return r.view(items[0], items[2])
+
+M = M_v3
 export(N(), (t,), strict=strict)

 def test_suggested_fixes_for_data_dependent_errors_puzzlers(self):
@@ -4443,29 +4489,6 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
     fixes=[],  # nothing to fix!
 )

-def test_simple_unbacked_view(self):
-    class Foo(torch.nn.Module):
-        def forward(self, x):
-            u0 = x.item()
-            y = torch.empty(5, u0)
-            return y.view(u0, 5)  # [5, u0] -> [u0, 5]
-
-    ep = export(Foo(), (torch.tensor([9]),))
-    self.assertEqual(ep.module()(torch.tensor([8])).size(0), 8)
-    self.assertEqual(ep.module()(torch.tensor([5])).size(0), 5)
-
-    class Foov2(torch.nn.Module):
-        def forward(self, xs):
-            xsl = xs.tolist()
-            a, b = xsl
-            x = torch.zeros(a)
-            return x.reshape(b)
-
-    xs = torch.tensor([4, 4])
-    ep = export(Foov2(), (xs,))
-    self.assertEqual(ep.module()(xs).size(0), 4)
-    self.assertEqual(ep.module()(torch.tensor([5, 5])).size(0), 5)
-
 def test_no_suggested_fixes_for_data_dependent_errors(self):
     # suggested fixes for data-dependent errors only work in non-strict mode
     strict = False
@@ -7388,19 +7411,22 @@ def forward(self, b_a_buffer, x):
     len([node for node in gm.graph.nodes if node.op == "placeholder"]), 2
 )

-def test_no_check_is_size_error(self):
+def test_check_is_size_error(self):
     class Module(torch.nn.Module):
         def forward(self, x):
             a = x.item()
+            # We cannot automatically infer a is a size here because view
+            # accepts -1
             return torch.randn(24).view(a, 4)

     f = Module()
-    ep = export(f, (torch.tensor(6),))
-    ep.module()(torch.tensor(6))
-    with self.assertRaisesRegex(
-        RuntimeError, r"Runtime assertion failed for .* u.* 6"
-    ):
-        ep.module()(torch.tensor(5))
+    if is_non_strict_test(self._testMethodName):
+        error = torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode
+    else:
+        error = torch._dynamo.exc.UserError
+    error_msg = r"Could not guard on data-dependent expression"
+    with self.assertRaisesRegex(error, error_msg):
+        _ = export(f, (torch.tensor(6),))

 def test_is_non_negative_check_function(self):
     import sympy as sp
@@ -13244,7 +13270,7 @@ def forward(self, x):
         node.target == torch.ops.aten._assert_scalar.default
         for node in ep.graph.nodes
     ].count(True)
-    self.assertEqual(num_asserts, 2)
+    self.assertEqual(num_asserts, 1)
     with self.assertRaises(RuntimeError):
         ep.module()(torch.randn(4, 2))
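The restored test above peels the data-dependent guards off one at a time (M_v0 through M_v3). Condensed into a single standalone module, the fully checked version looks roughly like this (a sketch: the real test's wrapper module N and error_type plumbing are omitted, and the input values are illustrative):

    import torch

    class N(torch.nn.Module):
        def forward(self, t):
            items = [t[i].item() for i in range(t.numel())]
            r = torch.randn([items[0], items[1]])
            torch._check(items[2] != -1)          # resolves Eq(u2, -1)
            torch._check(items[2] >= 0)           # resolves u2 >= 0
            torch._check(items[2] == r.shape[1])  # resolves Eq(u1, u2)
            return r.view(items[0], items[2])

    ep = torch.export.export(N(), (torch.tensor([3, 4, 4]),))
    print(ep.module()(torch.tensor([3, 4, 4])).shape)  # torch.Size([3, 4])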

torch/_prims_common/__init__.py

@@ -924,29 +924,24 @@ def infer_size(shape: ShapeType, numel: int) -> tuple[int, ...]:
     Infers the size of a dim with size -1, if it exists.
     Also checks that new shape is compatible with the number of elements.
     """
-    from torch.fx.experimental.symbolic_shapes import definitely_true, guard_or_false
-
     dim = None
     newsize = 1
     for i, d in enumerate(shape):
-        if guard_or_false(d == -1):
+        if d == -1:
             torch._check(dim is None, lambda: "only one dimension can be inferred")
             dim = i
-        else:
-            torch._check(
-                d >= 0,
-                lambda: (
-                    f"invalid shape dimension {d}. If this was symbolic, it was assumed to not be -1."
-                    "If this was meant to be inferred, please explicitly pass in -1."
-                ),
-            )
+        elif d >= 0:
             newsize *= d
+        else:
+            torch._check(False, lambda: f"invalid shape dimension {d}")
     if dim is None:
         torch._check(
             numel == newsize,
             lambda: f"shape '{list(shape)}' is invalid for input of size {numel}",
         )
     else:
+        from torch.fx.experimental.symbolic_shapes import definitely_true
+
         torch._check(
             newsize != 0,
             lambda: (
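For readers who don't want to reverse-engineer the hunk: infer_size resolves at most one -1 wildcard by dividing the total element count by the product of the known dimensions. A standalone sketch of that contract in plain Python (infer_size_sketch is a hypothetical name, not the torch._prims_common implementation):

    def infer_size_sketch(shape, numel):
        # Resolve a single -1 wildcard and validate the shape against numel.
        dim, newsize = None, 1
        for i, d in enumerate(shape):
            if d == -1:
                assert dim is None, "only one dimension can be inferred"
                dim = i
            elif d >= 0:
                newsize *= d
            else:
                raise ValueError(f"invalid shape dimension {d}")
        if dim is None:
            assert newsize == numel, f"shape {list(shape)} is invalid for input of size {numel}"
            return tuple(shape)
        # The wildcard dim gets whatever factor of numel the known dims leave over
        assert newsize != 0 and numel % newsize == 0
        out = list(shape)
        out[dim] = numel // newsize
        return tuple(out)

    print(infer_size_sketch([2, -1, 3], 24))  # (2, 4, 3)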

torch/_refs/__init__.py

@@ -3717,8 +3717,7 @@ def repeat(a: Tensor, *repeat_shape) -> Tensor:
 def _reshape_view_helper(a: TensorLikeType, *shape, allow_copy: bool) -> TensorLikeType:
-    from torch._dynamo.exc import UserError, UserErrorType
-    from torch.fx.experimental.symbolic_shapes import guard_or_false, guard_or_true
+    from torch.fx.experimental.symbolic_shapes import guard_size_oblivious, sym_eq

     # Creates a valid shape
     shape = utils.extract_shape_from_varargs(shape, validate=False)
@@ -3727,7 +3726,7 @@ def _reshape_view_helper(a: TensorLikeType, *shape, allow_copy: bool) -> TensorL
     shape = utils.infer_size(shape, a.numel())

     # Special-cases tensors with no elements
-    if guard_or_false(a.numel() == 0):
+    if guard_size_oblivious(a.numel() == 0):
         return as_strided(a, shape, utils.make_contiguous_strides_for(shape))

     # Special-cases reshaping zero dim tensors
@@ -3763,12 +3762,6 @@ def _reshape_view_helper(a: TensorLikeType, *shape, allow_copy: bool) -> TensorL
         return torch.as_strided(a, [dim0, dim1], [dim1, 1])

     # Handles general case: a 1+D tensor reshaped into a distinct 1+D shape
-    shape_numel = reduce(operator.mul, shape, 1)
-    torch._check(
-        a.numel() == shape_numel,
-        f"Could not reshape a tensor with shape {a.shape} as a tensor with shape {shape}!",
-    )
-    deferred: list[Callable[[], bool]] = []

     # NOTE [Reshape Algorithm]
     # This algorithm works by attempting to greedily construct the desired dimensions in
@@ -3801,30 +3794,16 @@ def _reshape_view_helper(a: TensorLikeType, *shape, allow_copy: bool) -> TensorL
             continue

         # Skips dimensions that are already the correct length
-        if guard_or_false(length == a_.shape[idx]):
+        if guard_size_oblivious(length == a_.shape[idx]):
             idx = idx + 1
             continue

         # Gathers enough original dimensions such that this new dimension can be created
         # Note that this accumulation will terminate because we've verified a and the shape
         # specify the same number of elements above
-        def maybe_throw_dde():
-            # NOTE: if you've hit a data-dependent error here, it's because in trying to accumulate input
-            # tensor dimensions to match the target shape (length), we've hit data-dependent errors testing
-            # divisibility (accum % length != 0), and have deferred raising them, in the hope that we'd
-            # figure out a valid reshape later in the loop.
-            # But we failed, either by running out of dimensions, or we couldn't figure out the strides,
-            # and we've decided to re-raise to either graph break out, or provide the exact guard so the user
-            # can torch._check() to avoid this.
-            for f in deferred:
-                f()
-
         accum = a_.shape[idx]
         end = idx
-        while guard_or_true(accum % length != 0):
-            deferred.append(lambda: bool(accum % length != 0))
-            if end == a_.ndim - 1:
-                maybe_throw_dde()
+        while guard_size_oblivious(accum % length != 0):
             end = end + 1
             accum = accum * a_.shape[end]
         if end != idx:
@@ -3838,15 +3817,13 @@ def _reshape_view_helper(a: TensorLikeType, *shape, allow_copy: bool) -> TensorL
             if allow_copy:
                 return prims.reshape(a, shape)

-            maybe_throw_dde()
             msg = f"Cannot view a tensor with shape {a.shape} and strides {a.stride()} as a tensor with shape {shape}!"
             raise ValueError(msg)

         a_ = flatten(a_, idx, end)

-        # Splits the (possibly flattened) dimension to create the desired dim length.
-        # guard_or_true is safe due to the tail unsqueeze routine.
-        if guard_or_true(accum != length):
+        # Splits the (possibly flattened) dimension to create the desired dim length
+        if guard_size_oblivious(accum != length):
             a_ = prims.split_dim(a_, idx, length)

         idx = idx + 1
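To make the restored control flow easier to follow, here is the greedy strategy from NOTE [Reshape Algorithm] reduced to plain integers: for each target length, accumulate trailing input dims until their product is divisible by it, flatten that span, then split off the target. This is a sketch under the assumption that the two shapes have equal element counts; strides, zero-size dims, and the view/copy distinction are ignored, and greedy_reshape_plan is a hypothetical name:

    from functools import reduce
    import operator

    def greedy_reshape_plan(in_shape, out_shape):
        # Returns the flatten/split steps that realize out_shape from in_shape.
        assert reduce(operator.mul, in_shape, 1) == reduce(operator.mul, out_shape, 1)
        steps, dims, idx = [], list(in_shape), 0
        for length in out_shape:
            if idx < len(dims) and dims[idx] == length:
                idx += 1  # dimension already has the correct length
                continue
            accum, end = dims[idx], idx
            while accum % length != 0:  # gather dims until evenly divisible
                end += 1
                accum *= dims[end]
            if end != idx:  # flatten the gathered span into one dim
                steps.append(("flatten", idx, end))
                dims[idx : end + 1] = [accum]
            if accum != length:  # split off the desired length
                steps.append(("split", idx, length))
                dims[idx : idx + 1] = [length, accum // length]
            idx += 1
        return steps

    print(greedy_reshape_plan((6, 4), (3, 8)))  # [('split', 0, 3), ('flatten', 1, 2)]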