mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
fix mypi in utils/_sympy/functions.py (#136339)
Signed-off-by: Bob Ren <bobren@fb.com> Turns out older versions of python, in particular 3.8 shows errors that 3.12 doesn't. For posterity these are the steps I took to reproduce: ``` conda create -n py38 python=3.8 conda activate py38 pip install -r requirements.txt lintrunner init dmypy restart && lintrunner --all-files --take MYPY ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/136339 Approved by: https://github.com/Skylion007 ghstack dependencies: #136205
This commit is contained in:
parent
f53a0f9cc1
commit
7f9c06462f
|
|
@ -89,10 +89,10 @@ __all__ = [
|
|||
]
|
||||
|
||||
|
||||
def _keep_float(f: Callable[..., _T]) -> Callable[..., sympy.Float]:
|
||||
def _keep_float(f: Callable[..., _T]) -> Callable[..., Union[_T, sympy.Float]]:
|
||||
@functools.wraps(f)
|
||||
def inner(*args: Any) -> Union[_T, sympy.Float]:
|
||||
r = f(*args)
|
||||
r: Union[_T, sympy.Float] = f(*args)
|
||||
if any(isinstance(a, sympy.Float) for a in args) and not isinstance(
|
||||
r, sympy.Float
|
||||
):
|
||||
|
|
@ -140,7 +140,7 @@ def simple_floordiv_gcd(p: sympy.Basic, q: sympy.Basic) -> sympy.Basic:
|
|||
return functools.reduce(math.gcd, integer_factors)
|
||||
|
||||
gcd: int = math.gcd(integer_factor(p), integer_factor(q))
|
||||
p, q = p / gcd, q / gcd
|
||||
p, q = p / gcd, q / gcd # type: ignore[operator, assignment] # remove in py3.12
|
||||
|
||||
base_splits: List[Tuple[sympy.Basic, ...]] = list(
|
||||
map(sympy.Mul.make_args, sympy.Add.make_args(p))
|
||||
|
|
@ -148,8 +148,8 @@ def simple_floordiv_gcd(p: sympy.Basic, q: sympy.Basic) -> sympy.Basic:
|
|||
divisor_split: Tuple[sympy.Basic, ...] = sympy.Mul.make_args(q)
|
||||
for x in divisor_split:
|
||||
if all(x in base_split for base_split in base_splits):
|
||||
gcd = gcd * x
|
||||
return gcd
|
||||
gcd = gcd * x # type: ignore[operator] # remove in py3.12
|
||||
return gcd # type: ignore[return-value] # remove in py3.12
|
||||
|
||||
|
||||
# It would be nice to have assertions on whether or not inputs is_integer
|
||||
|
|
@ -191,7 +191,7 @@ class FloorDiv(sympy.Function):
|
|||
def divisor(self) -> sympy.Basic:
|
||||
return self.args[1]
|
||||
|
||||
def _sympystr(self, printer: sympy.printing.printer.Printer) -> str:
|
||||
def _sympystr(self, printer: sympy.printing.StrPrinter) -> str:
|
||||
base = printer.parenthesize(self.base, self.precedence)
|
||||
divisor = printer.parenthesize(self.divisor, self.precedence)
|
||||
return f"({base}//{divisor})"
|
||||
|
|
@ -199,7 +199,9 @@ class FloorDiv(sympy.Function):
|
|||
# Automatic evaluation.
|
||||
# https://docs.sympy.org/latest/guides/custom-functions.html#best-practices-for-eval
|
||||
@classmethod
|
||||
def eval(cls, base: sympy.Basic, divisor: sympy.Basic) -> Union[sympy.Basic, None]:
|
||||
def eval(
|
||||
cls, base: sympy.Integer, divisor: sympy.Integer
|
||||
) -> Union[sympy.Basic, None]:
|
||||
# python test/test_dynamic_shapes.py -k TestDimConstraints.test_dim_constraints_solve_full
|
||||
# Assert triggered by inequality solver
|
||||
# assert base.is_integer, base
|
||||
|
|
@ -281,7 +283,7 @@ class ModularIndexing(sympy.Function):
|
|||
|
||||
@classmethod
|
||||
def eval(
|
||||
cls, base: sympy.Basic, divisor: sympy.Basic, modulus: sympy.Basic
|
||||
cls, base: sympy.Integer, divisor: sympy.Integer, modulus: sympy.Integer
|
||||
) -> Optional[sympy.Basic]:
|
||||
if base == 0 or modulus == 1:
|
||||
return sympy.Integer(0)
|
||||
|
|
@ -306,7 +308,7 @@ class ModularIndexing(sympy.Function):
|
|||
pass # https://github.com/pytorch/pytorch/issues/108276
|
||||
|
||||
if isinstance(base, sympy.Add):
|
||||
new_terms: List[sympy.Basic] = []
|
||||
new_terms: List[sympy.Integer] = []
|
||||
all_positive: bool = True
|
||||
for term in base.args:
|
||||
if sympy.gcd(term, modulus * divisor) != modulus * divisor:
|
||||
|
|
@ -1156,7 +1158,7 @@ class Identity(sympy.Function):
|
|||
Prevents expansion and other optimizations
|
||||
"""
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self): # type: ignore[override]
|
||||
return f"Identity({self.args[0]})"
|
||||
|
||||
def _eval_is_real(self):
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user