diff --git a/test/export/test_draft_export.py b/test/export/test_draft_export.py
index 03383a02d0f..f06f837c62e 100644
--- a/test/export/test_draft_export.py
+++ b/test/export/test_draft_export.py
@@ -323,7 +323,7 @@ class TestDraftExport(TestCase):
         self.assertEqual(
             report.failures[0].failure_type, FailureType.DATA_DEPENDENT_ERROR
         )
-        self.assertEqual(report.failures[0].data["expr"], "Eq(2*u1, 10)")
+        self.assertEqual(report.failures[0].data["expr"], "Eq(9380*u1, 0)")
 
     def test_dedup_data_dependent_failure(self):
         class M(torch.nn.Module):
@@ -480,6 +480,7 @@ class TestDraftExport(TestCase):
                 return torch.nn.functional.linear(masked, weight, bias)
 
         x = torch.zeros(10)
+        x[0] += 1
         inp = (torch.randn(10, 8, 7), x, torch.randn(25, 7), torch.randn(25))
         draft_ep = draft_export(M(), inp)
         ep = export(M(), inp)
diff --git a/test/export/test_export.py b/test/export/test_export.py
index 45b1bf83d43..550e5ab9050 100755
--- a/test/export/test_export.py
+++ b/test/export/test_export.py
@@ -4301,7 +4301,7 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
             def forward(self, t):
                 items = [t[i].item() for i in range(t.numel())]
                 r = torch.randn([items[0], items[1]])
-                # Could not guard on data-dependent expression Eq(u2, -1)
+                # Could not guard on data-dependent expression Ne(Mod(u1, u2), 0)
                 return r.view(items[0], items[2])
 
         M = M_v0
@@ -4310,9 +4310,10 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
             "The following call raised this error(.*\n)+"
             f".*{re.escape('return r.view(items[0], items[2])')}(.*\n)+"
             "To fix the error, insert one of the following checks before this call.*:\n"
-            f".*{re.escape('torch._check(items[2] == (-1))')}.*\n"
-            f".*{re.escape('torch._check(items[2] != (-1))')}(.*\n)+"
-            f".*{re.escape('(These suggested fixes were derived by replacing `u2` with items[2] in Eq(u2, -1) and its negation.)')}",
+            f".*{re.escape('torch._check((items[1] % items[2]) != 0)')}.*\n"
+            f".*{re.escape('torch._check((items[1] % items[2]) == 0)')}(.*\n)+"
+            f".*{re.escape('(These suggested fixes were derived by replacing `u1` with items[1]')}"
+            f".*{re.escape('or r.shape[1], `u2` with items[2] in Ne(Mod(u1, u2), 0) and its negation.')}",
         ):
             export(N(), (t,), strict=strict)
 
@@ -4320,59 +4321,12 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
             def forward(self, t):
                 items = [t[i].item() for i in range(t.numel())]
                 r = torch.randn([items[0], items[1]])
-                # Could not guard on data-dependent expression Eq(u2, -1)
-                torch._check(items[2] != -1)
-                # Could not guard on data-dependent expression u2 >= 0
+                # TODO(pianpwk): this isn't the suggested fixes.
+                # fix issue with % being interpreted as PythonMod instead of Mod
+                torch._check(items[1] == items[2])
                 return r.view(items[0], items[2])
 
         M = M_v1
-        with self.assertRaisesRegex(
-            error_type,
-            "The following call raised this error(.*\n)+"
-            f".*{re.escape('return r.view(items[0], items[2])')}(.*\n)+"
-            "To fix the error, insert one of the following checks before this call.*:\n"
-            f".*{re.escape('You can add either: torch._check_is_size(u2) or torch._check(u2>=0) Note: torch._check_is_size(u2) could prevent data dependent errors that happen in a guard_size_oblivious(..) context by opting into guard_size_oblivious reasoning. See documentation on guard_size_oblivious for more details: https://pytorch.org/docs/stable/generated/torch.fx.experimental.symbolic_shapes.guard_size_oblivious.html')}.*\n"
-            f".*{re.escape('torch._check(items[2] < 0)')}(.*\n)+"
-            f".*{re.escape('(These suggested fixes were derived by replacing `u2` with items[2] in u2 >= 0 and its negation.)')}",
-        ):
-            export(N(), (t,), strict=strict)
-
-        class M_v2(torch.nn.Module):
-            def forward(self, t):
-                items = [t[i].item() for i in range(t.numel())]
-                r = torch.randn([items[0], items[1]])
-                # Could not guard on data-dependent expression Eq(u2, -1)
-                torch._check(items[2] != -1)
-                # Could not guard on data-dependent expression u2 >= 0
-                torch._check(items[2] >= 0)
-                # Could not guard on data-dependent expression Eq(u1, u2)
-                return r.view(items[0], items[2])
-
-        M = M_v2
-        with self.assertRaisesRegex(
-            error_type,
-            "The following call raised this error(.*\n)+"
-            f".*{re.escape('return r.view(items[0], items[2])')}(.*\n)+"
-            "To fix the error, insert one of the following checks before this call.*:\n"
-            f".*{re.escape('torch._check(items[2] == items[1])')}.*\n"
-            f".*{re.escape('torch._check(items[2] != items[1])')}(.*\n)+"
-            f".*{re.escape('(These suggested fixes were derived by replacing `u1` with items[1] or r.shape[1], `u2` with items[2] in Eq(u2, u1) and its negation.)')}",
-        ):
-            export(N(), (t,), strict=strict)
-
-        class M_v3(torch.nn.Module):
-            def forward(self, t):
-                items = [t[i].item() for i in range(t.numel())]
-                r = torch.randn([items[0], items[1]])
-                # Could not guard on data-dependent expression Eq(u2, -1)
-                torch._check(items[2] != -1)
-                # Could not guard on data-dependent expression u2 >= 0
-                torch._check(items[2] >= 0)
-                # Could not guard on data-dependent expression Eq(u1, u2)
-                torch._check(items[2] == r.shape[1])
-                return r.view(items[0], items[2])
-
-        M = M_v3
         export(N(), (t,), strict=strict)
 
     def test_suggested_fixes_for_data_dependent_errors_puzzlers(self):
@@ -4484,6 +4438,29 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
             fixes=[],  # nothing to fix!
         )
 
+    def test_simple_unbacked_view(self):
+        class Foo(torch.nn.Module):
+            def forward(self, x):
+                u0 = x.item()
+                y = torch.empty(5, u0)
+                return y.view(u0, 5)  # [5, u0] -> [u0, 5]
+
+        ep = export(Foo(), (torch.tensor([9]),))
+        self.assertEqual(ep.module()(torch.tensor([8])).size(0), 8)
+        self.assertEqual(ep.module()(torch.tensor([5])).size(0), 5)
+
+        class Foov2(torch.nn.Module):
+            def forward(self, xs):
+                xsl = xs.tolist()
+                a, b = xsl
+                x = torch.zeros(a)
+                return x.reshape(b)
+
+        xs = torch.tensor([4, 4])
+        ep = export(Foov2(), (xs,))
+        self.assertEqual(ep.module()(xs).size(0), 4)
+        self.assertEqual(ep.module()(torch.tensor([5, 5])).size(0), 5)
+
     def test_no_suggested_fixes_for_data_dependent_errors(self):
         # suggested fixes for data-dependent errors only work in non-strict mode
         strict = False
@@ -7549,22 +7526,19 @@ def forward(self, b_a_buffer, x):
             len([node for node in gm.graph.nodes if node.op == "placeholder"]), 2
         )
 
-    def test_check_is_size_error(self):
+    def test_no_check_is_size_error(self):
         class Module(torch.nn.Module):
             def forward(self, x):
                 a = x.item()
-                # We cannot automatically infer a is a size here because view
-                # accepts -1
                 return torch.randn(24).view(a, 4)
 
         f = Module()
-        if is_non_strict_test(self._testMethodName):
-            error = torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode
-        else:
-            error = torch._dynamo.exc.UserError
-        error_msg = r"Could not guard on data-dependent expression"
-        with self.assertRaisesRegex(error, error_msg):
-            _ = export(f, (torch.tensor(6),))
+        ep = export(f, (torch.tensor(6),))
+        ep.module()(torch.tensor(6))
+        with self.assertRaisesRegex(
+            RuntimeError, r"Runtime assertion failed for .* u.* 6"
+        ):
+            ep.module()(torch.tensor(5))
 
     def test_is_non_negative_check_function(self):
         import sympy as sp
@@ -13487,7 +13461,7 @@ def forward(self, x):
             node.target == torch.ops.aten._assert_scalar.default
             for node in ep.graph.nodes
         ].count(True)
-        self.assertEqual(num_asserts, 1)
+        self.assertEqual(num_asserts, 2)
         with self.assertRaises(RuntimeError):
             ep.module()(torch.randn(4, 2))
 
diff --git a/torch/_prims_common/__init__.py b/torch/_prims_common/__init__.py
index e8339b789f5..dd2e6e27040 100644
--- a/torch/_prims_common/__init__.py
+++ b/torch/_prims_common/__init__.py
@@ -924,24 +924,29 @@ def infer_size(shape: ShapeType, numel: int) -> tuple[int, ...]:
     Infers the size of a dim with size -1, if it exists.
     Also checks that new shape is compatible with the number of elements.
     """
+    from torch.fx.experimental.symbolic_shapes import definitely_true, guard_or_false
+
     dim = None
     newsize = 1
     for i, d in enumerate(shape):
-        if d == -1:
+        if guard_or_false(d == -1):
             torch._check(dim is None, lambda: "only one dimension can be inferred")
             dim = i
-        elif d >= 0:
-            newsize *= d
         else:
-            torch._check(False, lambda: f"invalid shape dimension {d}")
+            torch._check(
+                d >= 0,
+                lambda: (
+                    f"invalid shape dimension {d}. If this was symbolic, it was assumed to not be -1. "
+                    "If this was meant to be inferred, please explicitly pass in -1."
+                ),
+            )
+            newsize *= d
     if dim is None:
         torch._check(
             numel == newsize,
             lambda: f"shape '{list(shape)}' is invalid for input of size {numel}",
         )
     else:
-        from torch.fx.experimental.symbolic_shapes import definitely_true
-
         torch._check(
             newsize != 0,
             lambda: (
diff --git a/torch/_refs/__init__.py b/torch/_refs/__init__.py
index c9080a01ede..2e91b69fc7d 100644
--- a/torch/_refs/__init__.py
+++ b/torch/_refs/__init__.py
@@ -3717,7 +3717,8 @@ def repeat(a: Tensor, *repeat_shape) -> Tensor:
 
 
 def _reshape_view_helper(a: TensorLikeType, *shape, allow_copy: bool) -> TensorLikeType:
-    from torch.fx.experimental.symbolic_shapes import guard_size_oblivious, sym_eq
+    from torch._dynamo.exc import UserError, UserErrorType
+    from torch.fx.experimental.symbolic_shapes import guard_or_false, guard_or_true
 
     # Creates a valid shape
     shape = utils.extract_shape_from_varargs(shape, validate=False)
@@ -3726,7 +3727,7 @@ def _reshape_view_helper(a: TensorLikeType, *shape, allow_copy: bool) -> TensorL
     shape = utils.infer_size(shape, a.numel())
 
     # Special-cases tensors with no elements
-    if guard_size_oblivious(a.numel() == 0):
+    if guard_or_false(a.numel() == 0):
         return as_strided(a, shape, utils.make_contiguous_strides_for(shape))
 
     # Special-cases reshaping zero dim tensors
@@ -3762,6 +3763,12 @@ def _reshape_view_helper(a: TensorLikeType, *shape, allow_copy: bool) -> TensorL
         return torch.as_strided(a, [dim0, dim1], [dim1, 1])
 
     # Handles general case: a 1+D tensor reshaped into a distinct 1+D shape
+    shape_numel = reduce(operator.mul, shape, 1)
+    torch._check(
+        a.numel() == shape_numel,
+        f"Could not reshape a tensor with shape {a.shape} as a tensor with shape {shape}!",
+    )
+    deferred: list[Callable[[], bool]] = []
 
     # NOTE [Reshape Algorithm]
     # This algorithm works by attempting to greedily construct the desired dimensions in
@@ -3794,16 +3801,30 @@ def _reshape_view_helper(a: TensorLikeType, *shape, allow_copy: bool) -> TensorL
             continue
 
         # Skips dimensions that are already the correct length
-        if guard_size_oblivious(length == a_.shape[idx]):
+        if guard_or_false(length == a_.shape[idx]):
            idx = idx + 1
            continue
 
         # Gathers enough original dimensions such that this new dimension can be created
         # Note that this accumulation will terminate because we've verified a and the shape
         # specify the same number of elements above
+        def maybe_throw_dde():
+            # NOTE: if you've hit a data-dependent error here, it's because in trying to accumulate input
+            # tensor dimensions to match the target shape (length), we've hit data-dependent errors testing
+            # divisibility (accum % length != 0), and have deferred raising them, in the hope that we'd
+            # figure out a valid reshape later in the loop.
+            # But we failed, either by running out of dimensions, or we couldn't figure out the strides,
+            # and we've decided to re-raise to either graph break out, or provide the exact guard so the user
+            # can torch._check() to avoid this.
+            for f in deferred:
+                f()
+
         accum = a_.shape[idx]
         end = idx
-        while guard_size_oblivious(accum % length != 0):
+        while guard_or_true(accum % length != 0):
+            deferred.append(lambda: bool(accum % length != 0))
+            if end == a_.ndim - 1:
+                maybe_throw_dde()
             end = end + 1
             accum = accum * a_.shape[end]
         if end != idx:
@@ -3817,13 +3838,15 @@ def _reshape_view_helper(a: TensorLikeType, *shape, allow_copy: bool) -> TensorL
                 if allow_copy:
                     return prims.reshape(a, shape)
 
+                maybe_throw_dde()
                 msg = f"Cannot view a tensor with shape {a.shape} and strides {a.stride()} as a tensor with shape {shape}!"
                 raise ValueError(msg)
 
             a_ = flatten(a_, idx, end)
 
-        # Splits the (possibly flattened) dimension to create the desired dim length
-        if guard_size_oblivious(accum != length):
+        # Splits the (possibly flattened) dimension to create the desired dim length.
+        # guard_or_true is safe due to the tail unsqueeze routine.
+        if guard_or_true(accum != length):
             a_ = prims.split_dim(a_, idx, length)
 
         idx = idx + 1
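
Below is a small, self-contained sketch of the user-visible behavior the new tests exercise (module names are illustrative, and it assumes a build that includes this patch). It mirrors test_simple_unbacked_view and test_no_check_is_size_error: a view on an unbacked size now exports without a torch._check_is_size hint, and a mismatched size surfaces as a runtime assertion on the exported module rather than a data-dependent guard error at trace time.

import torch
from torch.export import export


class UnbackedView(torch.nn.Module):  # illustrative name, mirrors test_simple_unbacked_view
    def forward(self, x):
        u0 = x.item()           # unbacked SymInt
        y = torch.empty(5, u0)
        return y.view(u0, 5)    # [5, u0] -> [u0, 5], no torch._check_is_size needed


class ViewBySize(torch.nn.Module):  # illustrative name, mirrors test_no_check_is_size_error
    def forward(self, x):
        a = x.item()
        return torch.randn(24).view(a, 4)


ep = export(UnbackedView(), (torch.tensor([9]),))
print(ep.module()(torch.tensor([8])).shape)  # torch.Size([8, 5])

ep2 = export(ViewBySize(), (torch.tensor(6),))
ep2.module()(torch.tensor(6))      # fine: 24 == 6 * 4
try:
    ep2.module()(torch.tensor(5))  # 24 != 5 * 4
except RuntimeError as e:
    print(e)                       # runtime assertion failure instead of a trace-time guard error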
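
Both infer_size and _reshape_view_helper now rely on guard_or_false / guard_or_true from torch.fx.experimental.symbolic_shapes: when a condition over unbacked symbols cannot be decided, these helpers fall back to a fixed answer instead of raising, and the real check is either deferred (the `deferred` list replayed by maybe_throw_dde) or left to a runtime assertion. A rough sketch of that fallback semantics, using ShapeEnv.create_unbacked_symint() purely as a stand-in for the symbol an .item() call produces during export:

from torch.fx.experimental.symbolic_shapes import ShapeEnv, guard_or_false, guard_or_true

shape_env = ShapeEnv()
u = shape_env.create_unbacked_symint()  # stand-in for the result of x.item() under export

# Undecidable, data-dependent conditions no longer raise; they fall back to a default:
print(guard_or_false(u == -1))    # False -> infer_size assumes a symbolic dim is not the -1 sentinel
print(guard_or_true(u % 3 != 0))  # True  -> the reshape loop keeps accumulating dims and
                                  #          defers the actual divisibility check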