mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Revert "bound sympy accuracy (#150383)"
This reverts commit 1bc2b2b12a.
Reverted https://github.com/pytorch/pytorch/pull/150383 on behalf of https://github.com/laithsakka due to big regression ([comment](https://github.com/pytorch/pytorch/pull/150383#issuecomment-2779227548))
This commit is contained in:
parent
f443035f10
commit
c93e34d7b5
|
|
@ -3105,28 +3105,6 @@ def forward(self, p_linear_weight, p_linear_bias, x):
|
||||||
"dy - 6 = 6" not in exc.args[0]
|
"dy - 6 = 6" not in exc.args[0]
|
||||||
) # don't suggest fix for non-root dim
|
) # don't suggest fix for non-root dim
|
||||||
|
|
||||||
@testing.expectedFailureLegacyExportNonStrict # FIXME constraint violation (guard: s0 - s0%8 != 1)
|
|
||||||
@testing.expectedFailureCppSerDes # FIXME data-dependent error (hinted: True, unhinted: s0 - s0%8 >= 0)
|
|
||||||
def test_bound_sympy_accuracy(self):
|
|
||||||
class Foo(torch.nn.Module):
|
|
||||||
def forward(self, x):
|
|
||||||
expr = x.shape[0] - (x.shape[0] % 8)
|
|
||||||
return torch.empty(expr)
|
|
||||||
|
|
||||||
ep = export(
|
|
||||||
Foo(),
|
|
||||||
(torch.randn(13),),
|
|
||||||
dynamic_shapes={"x": (Dim("dim", min=2),)},
|
|
||||||
)
|
|
||||||
|
|
||||||
(output,) = ep.graph.output_node().args[0]
|
|
||||||
sym_node = output.meta["val"].shape[0].node
|
|
||||||
vr = torch.utils._sympy.value_ranges.bound_sympy(
|
|
||||||
sym_node.expr,
|
|
||||||
sym_node.shape_env.var_to_range,
|
|
||||||
)
|
|
||||||
self.assertEqual(vr.lower, 0)
|
|
||||||
|
|
||||||
@unittest.skip("See https://github.com/pytorch/pytorch/issues/135759")
|
@unittest.skip("See https://github.com/pytorch/pytorch/issues/135759")
|
||||||
def test_keep_composite_ops_invalid(self):
|
def test_keep_composite_ops_invalid(self):
|
||||||
class Foo(torch.nn.Module):
|
class Foo(torch.nn.Module):
|
||||||
|
|
|
||||||
|
|
@ -1004,22 +1004,6 @@ class SymPyValueRangeAnalysis:
|
||||||
return ValueRanges.increasing_map(x, TruncToFloat)
|
return ValueRanges.increasing_map(x, TruncToFloat)
|
||||||
|
|
||||||
|
|
||||||
def _rewrite_for_value_range_analysis(expr: sympy.Expr):
|
|
||||||
"""
|
|
||||||
Sometimes accuracy of value range analysis can be improved
|
|
||||||
with simple rewriting rules.
|
|
||||||
"""
|
|
||||||
|
|
||||||
# Rewrite X - X%Y to (X//Y) * Y.
|
|
||||||
x, y = sympy.Wild("x"), sympy.Wild("y")
|
|
||||||
expr = expr.replace(
|
|
||||||
x - torch.utils._sympy.functions.Mod(x, y),
|
|
||||||
torch.utils._sympy.functions.FloorDiv(x, y) * y,
|
|
||||||
)
|
|
||||||
|
|
||||||
return expr
|
|
||||||
|
|
||||||
|
|
||||||
def bound_sympy(
|
def bound_sympy(
|
||||||
expr: sympy.Expr, ranges: Optional[dict[sympy.Symbol, ValueRanges]] = None
|
expr: sympy.Expr, ranges: Optional[dict[sympy.Symbol, ValueRanges]] = None
|
||||||
) -> ValueRanges:
|
) -> ValueRanges:
|
||||||
|
|
@ -1063,7 +1047,6 @@ def bound_sympy(
|
||||||
vr = ValueRanges.unknown()
|
vr = ValueRanges.unknown()
|
||||||
return vr
|
return vr
|
||||||
|
|
||||||
expr = _rewrite_for_value_range_analysis(expr)
|
|
||||||
return sympy_interp(
|
return sympy_interp(
|
||||||
SymPyValueRangeAnalysis, ranges, expr, missing_handler=missing_handler
|
SymPyValueRangeAnalysis, ranges, expr, missing_handler=missing_handler
|
||||||
)
|
)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user