Revert "bound sympy accuracy (#150383)"

This reverts commit 1bc2b2b12a.

Reverted https://github.com/pytorch/pytorch/pull/150383 on behalf of https://github.com/laithsakka due to big regression ([comment](https://github.com/pytorch/pytorch/pull/150383#issuecomment-2779227548))
This commit is contained in:
PyTorch MergeBot 2025-04-04 16:26:00 +00:00
parent f443035f10
commit c93e34d7b5
2 changed files with 0 additions and 39 deletions

View File

@ -3105,28 +3105,6 @@ def forward(self, p_linear_weight, p_linear_bias, x):
"dy - 6 = 6" not in exc.args[0] "dy - 6 = 6" not in exc.args[0]
) # don't suggest fix for non-root dim ) # don't suggest fix for non-root dim
@testing.expectedFailureLegacyExportNonStrict # FIXME constraint violation (guard: s0 - s0%8 != 1)
@testing.expectedFailureCppSerDes # FIXME data-dependent error (hinted: True, unhinted: s0 - s0%8 >= 0)
def test_bound_sympy_accuracy(self):
class Foo(torch.nn.Module):
def forward(self, x):
expr = x.shape[0] - (x.shape[0] % 8)
return torch.empty(expr)
ep = export(
Foo(),
(torch.randn(13),),
dynamic_shapes={"x": (Dim("dim", min=2),)},
)
(output,) = ep.graph.output_node().args[0]
sym_node = output.meta["val"].shape[0].node
vr = torch.utils._sympy.value_ranges.bound_sympy(
sym_node.expr,
sym_node.shape_env.var_to_range,
)
self.assertEqual(vr.lower, 0)
@unittest.skip("See https://github.com/pytorch/pytorch/issues/135759") @unittest.skip("See https://github.com/pytorch/pytorch/issues/135759")
def test_keep_composite_ops_invalid(self): def test_keep_composite_ops_invalid(self):
class Foo(torch.nn.Module): class Foo(torch.nn.Module):

View File

@ -1004,22 +1004,6 @@ class SymPyValueRangeAnalysis:
return ValueRanges.increasing_map(x, TruncToFloat) return ValueRanges.increasing_map(x, TruncToFloat)
def _rewrite_for_value_range_analysis(expr: sympy.Expr):
"""
Sometimes accuracy of value range analysis can be improved
with simple rewriting rules.
"""
# Rewrite X - X%Y to (X//Y) * Y.
x, y = sympy.Wild("x"), sympy.Wild("y")
expr = expr.replace(
x - torch.utils._sympy.functions.Mod(x, y),
torch.utils._sympy.functions.FloorDiv(x, y) * y,
)
return expr
def bound_sympy( def bound_sympy(
expr: sympy.Expr, ranges: Optional[dict[sympy.Symbol, ValueRanges]] = None expr: sympy.Expr, ranges: Optional[dict[sympy.Symbol, ValueRanges]] = None
) -> ValueRanges: ) -> ValueRanges:
@ -1063,7 +1047,6 @@ def bound_sympy(
vr = ValueRanges.unknown() vr = ValueRanges.unknown()
return vr return vr
expr = _rewrite_for_value_range_analysis(expr)
return sympy_interp( return sympy_interp(
SymPyValueRangeAnalysis, ranges, expr, missing_handler=missing_handler SymPyValueRangeAnalysis, ranges, expr, missing_handler=missing_handler
) )