mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
[fx] rewrite FloorDiv to match Python better (#90906)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90906 Approved by: https://github.com/ezyang
This commit is contained in:
parent
5e0d3458eb
commit
d13207c7ad
|
|
@ -20,7 +20,7 @@ import os
|
|||
from torch.utils._pytree import tree_map
|
||||
from torch.fx.experimental import symbolic_shapes
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
from torch.fx.experimental.symbolic_shapes import ShapeEnv, sym_float, guard_int, SymNode, sym_sqrt, sym_int, to_node
|
||||
from torch.fx.experimental.symbolic_shapes import FloorDiv, ShapeEnv, sym_float, guard_int, SymNode, sym_sqrt, sym_int, to_node
|
||||
from torch.utils._python_dispatch import TorchDispatchMode
|
||||
from torch import SymInt
|
||||
|
||||
|
|
@ -469,15 +469,6 @@ if COLLECT_EXPECT:
|
|||
atexit.register(print_seen)
|
||||
|
||||
expected_failure_sym_magic_methods = {
|
||||
('floordiv', 'SymFloat', 'float'), # Cannot convert complex to float
|
||||
('floordiv', 'float', 'SymFloat'), # Cannot convert complex to float
|
||||
('floordiv', 'SymFloat', 'SymFloat'), # Cannot convert complex to float
|
||||
('floordiv', 'SymFloat', 'int'), # Scalars are not close!
|
||||
('floordiv', 'float', 'SymInt'), # Scalars are not close!
|
||||
('floordiv', 'SymFloat', 'SymInt'), # Scalars are not close!
|
||||
('floordiv', 'SymInt', 'float'), # Cannot convert complex to float
|
||||
('floordiv', 'int', 'SymFloat'), # Cannot convert complex to float
|
||||
('floordiv', 'SymInt', 'SymFloat'), # Cannot convert complex to float
|
||||
}
|
||||
|
||||
@skipIfTorchDynamo("Creating ShapeEnv fails for confusing reasons (also we never expect dynamo to see code like this)")
|
||||
|
|
@ -598,5 +589,158 @@ class TestSymNumberMagicMethods(TestCase):
|
|||
|
||||
instantiate_parametrized_tests(TestSymNumberMagicMethods)
|
||||
|
||||
# Checks that we correctly implement Python floordiv semantics with FloorDiv.
|
||||
# See NOTE [ SymPy eval and assumptions ]
|
||||
class TestFloorDiv(TestCase):
|
||||
@skipIfNoSympy
|
||||
def test_floordiv(self):
|
||||
values = (
|
||||
# complex is parsed as SymPy Add by FloorDiv (even when created with
|
||||
# the complex constructor) and complex is not supported by Python
|
||||
# floordiv.
|
||||
1.5 + 2.5j,
|
||||
# These test type-promotion and flooring behavior:
|
||||
2.9,
|
||||
2.5,
|
||||
2.1,
|
||||
2.0,
|
||||
7,
|
||||
# These make sure we handle various short-circuits properly:
|
||||
1.0,
|
||||
0.0,
|
||||
1,
|
||||
0,
|
||||
# Note: booleans cannot be passed directly to FloorDiv and cannot
|
||||
# be directly used in arithmetic exprs in SymPy, but we make an
|
||||
# attempt to test them anyway.
|
||||
True,
|
||||
False,
|
||||
)
|
||||
|
||||
# This helps catch issues when flooring.
|
||||
neg_values = tuple(-x for x in values)
|
||||
|
||||
def python_func(x, y):
|
||||
return x // y
|
||||
|
||||
def torch_func(x, y):
|
||||
# Note: we fully evaluate here since FloorDiv might not always do
|
||||
# that.
|
||||
shape_env = ShapeEnv()
|
||||
return shape_env.evaluate_expr(FloorDiv(x, y))
|
||||
|
||||
def other_func(func, x, y):
|
||||
if func is python_func:
|
||||
return torch_func(x, y)
|
||||
else:
|
||||
return python_func(x, y)
|
||||
|
||||
funcs = (
|
||||
python_func,
|
||||
torch_func,
|
||||
)
|
||||
|
||||
# We do not check error messages on the Python side to avoid depending
|
||||
# on an interpreter version.
|
||||
for func, (x, y) in itertools.product(funcs, itertools.chain(
|
||||
itertools.product(values, values),
|
||||
itertools.product(neg_values, values),
|
||||
itertools.product(values, neg_values),
|
||||
itertools.product(neg_values, neg_values),
|
||||
)):
|
||||
def assert_unsupported_error(func, x, y):
|
||||
if func is torch_func:
|
||||
# makes sure we use the SymPy types
|
||||
x = sympy.sympify(x)
|
||||
y = sympy.sympify(y)
|
||||
err = (
|
||||
rf"unsupported operand type\(s\) for //: "
|
||||
rf"'{type(x).__name__}' and '{type(y).__name__}'"
|
||||
rf", expected integer or real"
|
||||
)
|
||||
else:
|
||||
err = ""
|
||||
self.assertRaisesRegex(TypeError, err, lambda: func(x, y))
|
||||
|
||||
if type(x) is complex or type(y) is complex:
|
||||
# complex is not supported by floordiv
|
||||
assert_unsupported_error(func, x, y)
|
||||
elif (type(x) is bool or type(y) is bool) and func is torch_func:
|
||||
# bools are not supported in arithmetic exprs in SymPy
|
||||
assert_unsupported_error(func, x, y)
|
||||
elif (type(x) is bool or type(y) is bool) and y != 0:
|
||||
# test bools against SymPy ints unless it's a div by zero
|
||||
int_x = int(x) if type(x) is bool else x
|
||||
int_y = int(y) if type(y) is bool else y
|
||||
self.assertEqual(func(x, y), other_func(func, int_x, int_y))
|
||||
elif y == 0:
|
||||
# div by zero
|
||||
if func is torch_func:
|
||||
err = "division by zero"
|
||||
else:
|
||||
err = ""
|
||||
self.assertRaisesRegex(ZeroDivisionError, err, lambda: func(x, y))
|
||||
else:
|
||||
# otherwise, compare results
|
||||
self.assertEqual(func(x, y), other_func(func, x, y))
|
||||
|
||||
@skipIfNoSympy
|
||||
def test_floordiv_simplify(self):
|
||||
# Checks that we eval exprs without free vars no matter which
|
||||
# simplify/eval func is called.
|
||||
expr = FloorDiv(6.28, (FloorDiv(6.28, 3.14)))
|
||||
shape_env = ShapeEnv()
|
||||
|
||||
# All these should return the same result.
|
||||
self.assertEqual(expr, 3) # fully eval'd automatically
|
||||
self.assertEqual(expr.doit(deep=False), 3)
|
||||
self.assertEqual(expr.doit(deep=True), 3)
|
||||
self.assertEqual(sympy.simplify(expr), 3)
|
||||
self.assertEqual(shape_env.simplify(expr), 3)
|
||||
self.assertEqual(shape_env.evaluate_expr(expr), 3)
|
||||
|
||||
@skipIfNoSympy
|
||||
def test_floordiv_assumptions(self):
|
||||
# We define two Symbols (with different names) for each type to make
|
||||
# sure the behavior is consistent regardless of whether both arguments
|
||||
# are the same object or not.
|
||||
cases = (
|
||||
sympy.Symbol("i1", integer=True),
|
||||
sympy.Symbol("i2", integer=True),
|
||||
sympy.Symbol("r1", real=True),
|
||||
sympy.Symbol("r2", real=True),
|
||||
sympy.Symbol("c1", complex=True, real=False, integer=False),
|
||||
sympy.Symbol("c2", complex=True, real=False, integer=False),
|
||||
sympy.Symbol("s1"),
|
||||
sympy.Symbol("s2"),
|
||||
)
|
||||
|
||||
for base, divisor in itertools.product(cases, repeat=2):
|
||||
def op():
|
||||
return FloorDiv(base, divisor)
|
||||
|
||||
def is_complex(x):
|
||||
return x.is_integer is False and x.is_real is False and x.is_complex
|
||||
|
||||
if is_complex(base) or is_complex(divisor):
|
||||
self.assertRaisesRegex(
|
||||
TypeError,
|
||||
(r"unsupported operand type\(s\) for //: 'Symbol' and 'Symbol',"
|
||||
r" expected integer or real"),
|
||||
op)
|
||||
continue
|
||||
|
||||
op = op()
|
||||
|
||||
# In regular Python, x//x == 1.0 if x is a float, but FloorDiv
|
||||
# always returns an integer 1 when both args are the same object.
|
||||
# This even works for Symbols with no assumptions specified.
|
||||
if base is divisor or (base.is_integer and divisor.is_integer):
|
||||
self.assertTrue(op.is_integer)
|
||||
self.assertTrue(op.is_real)
|
||||
else:
|
||||
self.assertEqual(op.is_integer, None)
|
||||
self.assertTrue(op.is_real)
|
||||
|
||||
if __name__ == '__main__':
|
||||
run_tests()
|
||||
|
|
|
|||
|
|
@ -24,6 +24,7 @@ try:
|
|||
import sympy # type: ignore[import]
|
||||
from sympy.printing.precedence import precedence # type: ignore[import] # noqa: F401
|
||||
from sympy.printing.str import StrPrinter # type: ignore[import]
|
||||
from sympy.core.logic import fuzzy_and, fuzzy_or # type: ignore[import]
|
||||
HAS_SYMPY = True
|
||||
except ImportError:
|
||||
HAS_SYMPY = False
|
||||
|
|
@ -210,6 +211,19 @@ class SymNode:
|
|||
|
||||
|
||||
if HAS_SYMPY:
|
||||
# NOTE [ SymPy eval and assumptions ]
|
||||
# In eval, we only return values in cases where we always want to evaluate.
|
||||
# In other cases, the result will just be FloorDiv(a, b), which needs to be
|
||||
# evaluated later if necessary.
|
||||
#
|
||||
# We also define is_real=True and provide _eval_* methods to make the SymPy
|
||||
# assumptions system aware of Python floordiv semantics. For instance, this
|
||||
# ensures that correct assumptions are propagated when working with SymPy
|
||||
# Symbols. Two integer Symbols should return an integer result.
|
||||
#
|
||||
# https://peps.python.org/pep-0238/#semantics-of-floor-division
|
||||
# https://docs.sympy.org/latest/guides/assumptions.html#implementing-assumptions-handlers
|
||||
# https://docs.sympy.org/latest/guides/custom-functions.html#best-practices-for-eval
|
||||
class FloorDiv(sympy.Function):
|
||||
"""
|
||||
We maintain this so that:
|
||||
|
|
@ -219,21 +233,61 @@ if HAS_SYMPY:
|
|||
nargs = (2,)
|
||||
precedence = 50 # precedence of mul # noqa: F811
|
||||
|
||||
def _sympystr(self, printer):
|
||||
lhs = self.args[0]
|
||||
rhs = self.args[1]
|
||||
lhs_str = printer.parenthesize(lhs, self.precedence)
|
||||
rhs_str = printer.parenthesize(rhs, self.precedence)
|
||||
return f"{lhs_str}//{rhs_str}"
|
||||
# Default return type. For instance, this applies when both arguments
|
||||
# are Symbols without any assumptions.
|
||||
# See NOTE [ SymPy eval and assumptions ]
|
||||
is_real = True
|
||||
|
||||
@property
|
||||
def base(self):
|
||||
return self.args[0]
|
||||
|
||||
@property
|
||||
def divisor(self):
|
||||
return self.args[1]
|
||||
|
||||
def _sympystr(self, printer):
|
||||
base = printer.parenthesize(self.base, self.precedence)
|
||||
divisor = printer.parenthesize(self.divisor, self.precedence)
|
||||
return f"{base}//{divisor}"
|
||||
|
||||
# Assumptions based on argument types.
|
||||
# See NOTE [ SymPy eval and assumptions ]
|
||||
def _eval_is_real(self):
|
||||
return fuzzy_or([self.base.is_real, self.divisor.is_real])
|
||||
|
||||
def _eval_is_integer(self):
|
||||
return fuzzy_and([self.base.is_integer, self.divisor.is_integer])
|
||||
|
||||
# Automatic evaluation.
|
||||
# See NOTE [ SymPy eval and assumptions ]
|
||||
@classmethod
|
||||
def eval(cls, base, divisor):
|
||||
if base == 0:
|
||||
return sympy.Integer(0)
|
||||
def check_supported_type(x):
|
||||
if (x.is_integer is False and x.is_real is False and x.is_complex) or x.is_Boolean:
|
||||
raise TypeError(
|
||||
f"unsupported operand type(s) for //: "
|
||||
f"'{type(base).__name__}' and '{type(divisor).__name__}'"
|
||||
f", expected integer or real")
|
||||
|
||||
check_supported_type(base)
|
||||
check_supported_type(divisor)
|
||||
|
||||
# We don't provide the same error message as in Python because SymPy
|
||||
# makes it difficult to check the types.
|
||||
if divisor.is_zero:
|
||||
raise ZeroDivisionError("division by zero")
|
||||
|
||||
# We don't cast the return type as in Python because SymPy makes it
|
||||
# difficult to check the types.
|
||||
if base.is_zero:
|
||||
return sympy.S.Zero
|
||||
if divisor == 1:
|
||||
return base
|
||||
return sympy.floor(base)
|
||||
if isinstance(base, sympy.Integer) and isinstance(divisor, sympy.Integer):
|
||||
return base // divisor
|
||||
if isinstance(base, (sympy.Integer, sympy.Float)) and isinstance(divisor, (sympy.Integer, sympy.Float)):
|
||||
return sympy.floor(base / divisor)
|
||||
if isinstance(base, FloorDiv):
|
||||
return FloorDiv(base.args[0], base.args[1] * divisor)
|
||||
|
||||
|
|
@ -243,6 +297,7 @@ if HAS_SYMPY:
|
|||
sympy.simplify(base / gcd), sympy.simplify(divisor / gcd)
|
||||
)
|
||||
|
||||
|
||||
# Methods that have a `__foo__` as well as `__rfoo__`
|
||||
reflectable_magic_methods = {
|
||||
'add': lambda a, b: a + b,
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user