mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Keep zero check be compatible with different sympy versions (#130729)
# Motivation
I found a difference between sympy 1.12 and 1.13.
```python
# for 1.12
>>> import sympy
>>> a = sympy.Number(0.0)
>>> a == 0
True
```
```python
# for 1.13
>>> import sympy
>>> a = sympy.Number(0.0)
>>> a == 0
False
```
The different behavior will impact the result of [safe_mul](6beec34b1c/torch/utils/_sympy/value_ranges.py (L521-L528)), resulting in an incorrect results when `a = sympy.Number(0.0)`, `b = inf` and the result is `nan` if sympy version is 1.13. (the expected result is **0**)
```python
def safe_mul(a, b):
# Make unknown() * wrap(0.0) == wrap(0.0)
if a == 0.0:
return a
elif b == 0.0:
return b
else:
return a * b
```
In different sympy versions, `sympy.Number(0)` always has the same behavior that equals to 0.0.
```python
>>> import sympy
>>> a = sympy.Number(0)
>>> a == 0.0
True # for different sympy versions
```
So, use 0.0 when checking zero in safe_mul to keep compatible with different sympy versions.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/130729
Approved by: https://github.com/lezcano, https://github.com/EikanWang
This commit is contained in:
parent
fedae41c57
commit
096dc444ce
|
|
@ -241,6 +241,10 @@ class TestValueRanges(TestCase):
|
|||
ValueRangeAnalysis.mul(ValueRanges.wrap(0), ValueRanges.unknown()),
|
||||
ValueRanges.wrap(0),
|
||||
)
|
||||
self.assertEqual(
|
||||
ValueRangeAnalysis.mul(ValueRanges.wrap(0.0), ValueRanges.unknown()),
|
||||
ValueRanges.wrap(0.0),
|
||||
)
|
||||
|
||||
@parametrize("fn", UNARY_BOOL_OPS)
|
||||
def test_unary_bool_ref_range(self, fn):
|
||||
|
|
|
|||
|
|
@ -519,10 +519,10 @@ class SymPyValueRangeAnalysis:
|
|||
return cls.and_(a, b)
|
||||
|
||||
def safe_mul(a, b):
|
||||
# Make unknown() * wrap(0) == wrap(0)
|
||||
if a == 0:
|
||||
# Make unknown() * wrap(0.0) == wrap(0.0)
|
||||
if a == 0.0:
|
||||
return a
|
||||
elif b == 0:
|
||||
elif b == 0.0:
|
||||
return b
|
||||
else:
|
||||
return a * b
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user