Keep zero check be compatible with different sympy versions (#130729)

# Motivation
I found a difference between sympy 1.12 and 1.13.
```python
# for 1.12
>>> import sympy
>>> a = sympy.Number(0.0)
>>> a == 0
True
```
```python
# for 1.13
>>> import sympy
>>> a = sympy.Number(0.0)
>>> a == 0
False
```
The different behavior will impact the result of [safe_mul](6beec34b1c/torch/utils/_sympy/value_ranges.py (L521-L528)), resulting in an incorrect results when `a = sympy.Number(0.0)`, `b = inf` and the result is `nan` if sympy version is 1.13. (the expected result is **0**)
```python
def safe_mul(a, b):
    # Make unknown() * wrap(0.0) == wrap(0.0)
    if a == 0.0:
        return a
    elif b == 0.0:
        return b
    else:
        return a * b
```

In different sympy versions, `sympy.Number(0)` always has the same behavior that equals to 0.0.
```python
>>> import sympy
>>> a = sympy.Number(0)
>>> a == 0.0
True # for different sympy versions
```
So, use 0.0 when checking zero in safe_mul to keep compatible with different sympy versions.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/130729
Approved by: https://github.com/lezcano, https://github.com/EikanWang
This commit is contained in:
Yu, Guangye 2024-07-16 01:16:18 +00:00 committed by PyTorch MergeBot
parent fedae41c57
commit 096dc444ce
2 changed files with 7 additions and 3 deletions

View File

@ -241,6 +241,10 @@ class TestValueRanges(TestCase):
ValueRangeAnalysis.mul(ValueRanges.wrap(0), ValueRanges.unknown()),
ValueRanges.wrap(0),
)
self.assertEqual(
ValueRangeAnalysis.mul(ValueRanges.wrap(0.0), ValueRanges.unknown()),
ValueRanges.wrap(0.0),
)
@parametrize("fn", UNARY_BOOL_OPS)
def test_unary_bool_ref_range(self, fn):

View File

@ -519,10 +519,10 @@ class SymPyValueRangeAnalysis:
return cls.and_(a, b)
def safe_mul(a, b):
# Make unknown() * wrap(0) == wrap(0)
if a == 0:
# Make unknown() * wrap(0.0) == wrap(0.0)
if a == 0.0:
return a
elif b == 0:
elif b == 0.0:
return b
else:
return a * b