From c93e34d7b5690ee77cd29dc7b28a8aa7f61d58aa Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Fri, 4 Apr 2025 16:26:00 +0000 Subject: [PATCH] Revert "bound sympy accuracy (#150383)" This reverts commit 1bc2b2b12ae1ddd27b0401a1baac3b8099b6fc50. Reverted https://github.com/pytorch/pytorch/pull/150383 on behalf of https://github.com/laithsakka due to big regression ([comment](https://github.com/pytorch/pytorch/pull/150383#issuecomment-2779227548)) --- test/export/test_export.py | 22 ---------------------- torch/utils/_sympy/value_ranges.py | 17 ----------------- 2 files changed, 39 deletions(-) diff --git a/test/export/test_export.py b/test/export/test_export.py index 5eefb67c14b..988e2fae81c 100755 --- a/test/export/test_export.py +++ b/test/export/test_export.py @@ -3105,28 +3105,6 @@ def forward(self, p_linear_weight, p_linear_bias, x): "dy - 6 = 6" not in exc.args[0] ) # don't suggest fix for non-root dim - @testing.expectedFailureLegacyExportNonStrict # FIXME constraint violation (guard: s0 - s0%8 != 1) - @testing.expectedFailureCppSerDes # FIXME data-dependent error (hinted: True, unhinted: s0 - s0%8 >= 0) - def test_bound_sympy_accuracy(self): - class Foo(torch.nn.Module): - def forward(self, x): - expr = x.shape[0] - (x.shape[0] % 8) - return torch.empty(expr) - - ep = export( - Foo(), - (torch.randn(13),), - dynamic_shapes={"x": (Dim("dim", min=2),)}, - ) - - (output,) = ep.graph.output_node().args[0] - sym_node = output.meta["val"].shape[0].node - vr = torch.utils._sympy.value_ranges.bound_sympy( - sym_node.expr, - sym_node.shape_env.var_to_range, - ) - self.assertEqual(vr.lower, 0) - @unittest.skip("See https://github.com/pytorch/pytorch/issues/135759") def test_keep_composite_ops_invalid(self): class Foo(torch.nn.Module): diff --git a/torch/utils/_sympy/value_ranges.py b/torch/utils/_sympy/value_ranges.py index 118959b8c4d..784f9e7ba05 100644 --- a/torch/utils/_sympy/value_ranges.py +++ b/torch/utils/_sympy/value_ranges.py @@ -1004,22 +1004,6 @@ class SymPyValueRangeAnalysis: return ValueRanges.increasing_map(x, TruncToFloat) -def _rewrite_for_value_range_analysis(expr: sympy.Expr): - """ - Sometimes accuracy of value range analysis can be improved - with simple rewriting rules. - """ - - # Rewrite X - X%Y to (X//Y) * Y. - x, y = sympy.Wild("x"), sympy.Wild("y") - expr = expr.replace( - x - torch.utils._sympy.functions.Mod(x, y), - torch.utils._sympy.functions.FloorDiv(x, y) * y, - ) - - return expr - - def bound_sympy( expr: sympy.Expr, ranges: Optional[dict[sympy.Symbol, ValueRanges]] = None ) -> ValueRanges: @@ -1063,7 +1047,6 @@ def bound_sympy( vr = ValueRanges.unknown() return vr - expr = _rewrite_for_value_range_analysis(expr) return sympy_interp( SymPyValueRangeAnalysis, ranges, expr, missing_handler=missing_handler )