fix mypi in utils/_sympy/functions.py (#136339)

Signed-off-by: Bob Ren <bobren@fb.com>

Turns out older versions of python, in particular 3.8 shows errors that 3.12 doesn't. For posterity these are the steps I took to reproduce:

```
conda create -n py38 python=3.8
conda activate py38
pip install -r requirements.txt
lintrunner init
dmypy restart && lintrunner --all-files --take MYPY
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/136339
Approved by: https://github.com/Skylion007
ghstack dependencies: #136205
This commit is contained in:
Bob Ren 2024-09-19 16:41:00 -07:00 committed by PyTorch MergeBot
parent f53a0f9cc1
commit 7f9c06462f

View File

@ -89,10 +89,10 @@ __all__ = [
]
def _keep_float(f: Callable[..., _T]) -> Callable[..., sympy.Float]:
def _keep_float(f: Callable[..., _T]) -> Callable[..., Union[_T, sympy.Float]]:
@functools.wraps(f)
def inner(*args: Any) -> Union[_T, sympy.Float]:
r = f(*args)
r: Union[_T, sympy.Float] = f(*args)
if any(isinstance(a, sympy.Float) for a in args) and not isinstance(
r, sympy.Float
):
@ -140,7 +140,7 @@ def simple_floordiv_gcd(p: sympy.Basic, q: sympy.Basic) -> sympy.Basic:
return functools.reduce(math.gcd, integer_factors)
gcd: int = math.gcd(integer_factor(p), integer_factor(q))
p, q = p / gcd, q / gcd
p, q = p / gcd, q / gcd # type: ignore[operator, assignment] # remove in py3.12
base_splits: List[Tuple[sympy.Basic, ...]] = list(
map(sympy.Mul.make_args, sympy.Add.make_args(p))
@ -148,8 +148,8 @@ def simple_floordiv_gcd(p: sympy.Basic, q: sympy.Basic) -> sympy.Basic:
divisor_split: Tuple[sympy.Basic, ...] = sympy.Mul.make_args(q)
for x in divisor_split:
if all(x in base_split for base_split in base_splits):
gcd = gcd * x
return gcd
gcd = gcd * x # type: ignore[operator] # remove in py3.12
return gcd # type: ignore[return-value] # remove in py3.12
# It would be nice to have assertions on whether or not inputs is_integer
@ -191,7 +191,7 @@ class FloorDiv(sympy.Function):
def divisor(self) -> sympy.Basic:
return self.args[1]
def _sympystr(self, printer: sympy.printing.printer.Printer) -> str:
def _sympystr(self, printer: sympy.printing.StrPrinter) -> str:
base = printer.parenthesize(self.base, self.precedence)
divisor = printer.parenthesize(self.divisor, self.precedence)
return f"({base}//{divisor})"
@ -199,7 +199,9 @@ class FloorDiv(sympy.Function):
# Automatic evaluation.
# https://docs.sympy.org/latest/guides/custom-functions.html#best-practices-for-eval
@classmethod
def eval(cls, base: sympy.Basic, divisor: sympy.Basic) -> Union[sympy.Basic, None]:
def eval(
cls, base: sympy.Integer, divisor: sympy.Integer
) -> Union[sympy.Basic, None]:
# python test/test_dynamic_shapes.py -k TestDimConstraints.test_dim_constraints_solve_full
# Assert triggered by inequality solver
# assert base.is_integer, base
@ -281,7 +283,7 @@ class ModularIndexing(sympy.Function):
@classmethod
def eval(
cls, base: sympy.Basic, divisor: sympy.Basic, modulus: sympy.Basic
cls, base: sympy.Integer, divisor: sympy.Integer, modulus: sympy.Integer
) -> Optional[sympy.Basic]:
if base == 0 or modulus == 1:
return sympy.Integer(0)
@ -306,7 +308,7 @@ class ModularIndexing(sympy.Function):
pass # https://github.com/pytorch/pytorch/issues/108276
if isinstance(base, sympy.Add):
new_terms: List[sympy.Basic] = []
new_terms: List[sympy.Integer] = []
all_positive: bool = True
for term in base.args:
if sympy.gcd(term, modulus * divisor) != modulus * divisor:
@ -1156,7 +1158,7 @@ class Identity(sympy.Function):
Prevents expansion and other optimizations
"""
def __repr__(self):
def __repr__(self): # type: ignore[override]
return f"Identity({self.args[0]})"
def _eval_is_real(self):