[fx] rewrite FloorDiv to match Python better (#90906)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/90906
Approved by: https://github.com/ezyang
This commit is contained in:
Nikita Karetnikov 2023-01-16 10:52:35 +01:00 committed by PyTorch MergeBot
parent 5e0d3458eb
commit d13207c7ad
2 changed files with 218 additions and 19 deletions

View File

@ -20,7 +20,7 @@ import os
from torch.utils._pytree import tree_map
from torch.fx.experimental import symbolic_shapes
from torch.fx.experimental.proxy_tensor import make_fx
from torch.fx.experimental.symbolic_shapes import ShapeEnv, sym_float, guard_int, SymNode, sym_sqrt, sym_int, to_node
from torch.fx.experimental.symbolic_shapes import FloorDiv, ShapeEnv, sym_float, guard_int, SymNode, sym_sqrt, sym_int, to_node
from torch.utils._python_dispatch import TorchDispatchMode
from torch import SymInt
@ -469,15 +469,6 @@ if COLLECT_EXPECT:
atexit.register(print_seen)
expected_failure_sym_magic_methods = {
('floordiv', 'SymFloat', 'float'), # Cannot convert complex to float
('floordiv', 'float', 'SymFloat'), # Cannot convert complex to float
('floordiv', 'SymFloat', 'SymFloat'), # Cannot convert complex to float
('floordiv', 'SymFloat', 'int'), # Scalars are not close!
('floordiv', 'float', 'SymInt'), # Scalars are not close!
('floordiv', 'SymFloat', 'SymInt'), # Scalars are not close!
('floordiv', 'SymInt', 'float'), # Cannot convert complex to float
('floordiv', 'int', 'SymFloat'), # Cannot convert complex to float
('floordiv', 'SymInt', 'SymFloat'), # Cannot convert complex to float
}
@skipIfTorchDynamo("Creating ShapeEnv fails for confusing reasons (also we never expect dynamo to see code like this)")
@ -598,5 +589,158 @@ class TestSymNumberMagicMethods(TestCase):
instantiate_parametrized_tests(TestSymNumberMagicMethods)
# Checks that we correctly implement Python floordiv semantics with FloorDiv.
# See NOTE [ SymPy eval and assumptions ]
class TestFloorDiv(TestCase):
@skipIfNoSympy
def test_floordiv(self):
values = (
# complex is parsed as SymPy Add by FloorDiv (even when created with
# the complex constructor) and complex is not supported by Python
# floordiv.
1.5 + 2.5j,
# These test type-promotion and flooring behavior:
2.9,
2.5,
2.1,
2.0,
7,
# These make sure we handle various short-circuits properly:
1.0,
0.0,
1,
0,
# Note: booleans cannot be passed directly to FloorDiv and cannot
# be directly used in arithmetic exprs in SymPy, but we make an
# attempt to test them anyway.
True,
False,
)
# This helps catch issues when flooring.
neg_values = tuple(-x for x in values)
def python_func(x, y):
return x // y
def torch_func(x, y):
# Note: we fully evaluate here since FloorDiv might not always do
# that.
shape_env = ShapeEnv()
return shape_env.evaluate_expr(FloorDiv(x, y))
def other_func(func, x, y):
if func is python_func:
return torch_func(x, y)
else:
return python_func(x, y)
funcs = (
python_func,
torch_func,
)
# We do not check error messages on the Python side to avoid depending
# on an interpreter version.
for func, (x, y) in itertools.product(funcs, itertools.chain(
itertools.product(values, values),
itertools.product(neg_values, values),
itertools.product(values, neg_values),
itertools.product(neg_values, neg_values),
)):
def assert_unsupported_error(func, x, y):
if func is torch_func:
# makes sure we use the SymPy types
x = sympy.sympify(x)
y = sympy.sympify(y)
err = (
rf"unsupported operand type\(s\) for //: "
rf"'{type(x).__name__}' and '{type(y).__name__}'"
rf", expected integer or real"
)
else:
err = ""
self.assertRaisesRegex(TypeError, err, lambda: func(x, y))
if type(x) is complex or type(y) is complex:
# complex is not supported by floordiv
assert_unsupported_error(func, x, y)
elif (type(x) is bool or type(y) is bool) and func is torch_func:
# bools are not supported in arithmetic exprs in SymPy
assert_unsupported_error(func, x, y)
elif (type(x) is bool or type(y) is bool) and y != 0:
# test bools against SymPy ints unless it's a div by zero
int_x = int(x) if type(x) is bool else x
int_y = int(y) if type(y) is bool else y
self.assertEqual(func(x, y), other_func(func, int_x, int_y))
elif y == 0:
# div by zero
if func is torch_func:
err = "division by zero"
else:
err = ""
self.assertRaisesRegex(ZeroDivisionError, err, lambda: func(x, y))
else:
# otherwise, compare results
self.assertEqual(func(x, y), other_func(func, x, y))
@skipIfNoSympy
def test_floordiv_simplify(self):
# Checks that we eval exprs without free vars no matter which
# simplify/eval func is called.
expr = FloorDiv(6.28, (FloorDiv(6.28, 3.14)))
shape_env = ShapeEnv()
# All these should return the same result.
self.assertEqual(expr, 3) # fully eval'd automatically
self.assertEqual(expr.doit(deep=False), 3)
self.assertEqual(expr.doit(deep=True), 3)
self.assertEqual(sympy.simplify(expr), 3)
self.assertEqual(shape_env.simplify(expr), 3)
self.assertEqual(shape_env.evaluate_expr(expr), 3)
@skipIfNoSympy
def test_floordiv_assumptions(self):
# We define two Symbols (with different names) for each type to make
# sure the behavior is consistent regardless of whether both arguments
# are the same object or not.
cases = (
sympy.Symbol("i1", integer=True),
sympy.Symbol("i2", integer=True),
sympy.Symbol("r1", real=True),
sympy.Symbol("r2", real=True),
sympy.Symbol("c1", complex=True, real=False, integer=False),
sympy.Symbol("c2", complex=True, real=False, integer=False),
sympy.Symbol("s1"),
sympy.Symbol("s2"),
)
for base, divisor in itertools.product(cases, repeat=2):
def op():
return FloorDiv(base, divisor)
def is_complex(x):
return x.is_integer is False and x.is_real is False and x.is_complex
if is_complex(base) or is_complex(divisor):
self.assertRaisesRegex(
TypeError,
(r"unsupported operand type\(s\) for //: 'Symbol' and 'Symbol',"
r" expected integer or real"),
op)
continue
op = op()
# In regular Python, x//x == 1.0 if x is a float, but FloorDiv
# always returns an integer 1 when both args are the same object.
# This even works for Symbols with no assumptions specified.
if base is divisor or (base.is_integer and divisor.is_integer):
self.assertTrue(op.is_integer)
self.assertTrue(op.is_real)
else:
self.assertEqual(op.is_integer, None)
self.assertTrue(op.is_real)
if __name__ == '__main__':
run_tests()

View File

@ -24,6 +24,7 @@ try:
import sympy # type: ignore[import]
from sympy.printing.precedence import precedence # type: ignore[import] # noqa: F401
from sympy.printing.str import StrPrinter # type: ignore[import]
from sympy.core.logic import fuzzy_and, fuzzy_or # type: ignore[import]
HAS_SYMPY = True
except ImportError:
HAS_SYMPY = False
@ -210,6 +211,19 @@ class SymNode:
if HAS_SYMPY:
# NOTE [ SymPy eval and assumptions ]
# In eval, we only return values in cases where we always want to evaluate.
# In other cases, the result will just be FloorDiv(a, b), which needs to be
# evaluated later if necessary.
#
# We also define is_real=True and provide _eval_* methods to make the SymPy
# assumptions system aware of Python floordiv semantics. For instance, this
# ensures that correct assumptions are propagated when working with SymPy
# Symbols. Two integer Symbols should return an integer result.
#
# https://peps.python.org/pep-0238/#semantics-of-floor-division
# https://docs.sympy.org/latest/guides/assumptions.html#implementing-assumptions-handlers
# https://docs.sympy.org/latest/guides/custom-functions.html#best-practices-for-eval
class FloorDiv(sympy.Function):
"""
We maintain this so that:
@ -219,21 +233,61 @@ if HAS_SYMPY:
nargs = (2,)
precedence = 50 # precedence of mul # noqa: F811
def _sympystr(self, printer):
lhs = self.args[0]
rhs = self.args[1]
lhs_str = printer.parenthesize(lhs, self.precedence)
rhs_str = printer.parenthesize(rhs, self.precedence)
return f"{lhs_str}//{rhs_str}"
# Default return type. For instance, this applies when both arguments
# are Symbols without any assumptions.
# See NOTE [ SymPy eval and assumptions ]
is_real = True
@property
def base(self):
return self.args[0]
@property
def divisor(self):
return self.args[1]
def _sympystr(self, printer):
base = printer.parenthesize(self.base, self.precedence)
divisor = printer.parenthesize(self.divisor, self.precedence)
return f"{base}//{divisor}"
# Assumptions based on argument types.
# See NOTE [ SymPy eval and assumptions ]
def _eval_is_real(self):
return fuzzy_or([self.base.is_real, self.divisor.is_real])
def _eval_is_integer(self):
return fuzzy_and([self.base.is_integer, self.divisor.is_integer])
# Automatic evaluation.
# See NOTE [ SymPy eval and assumptions ]
@classmethod
def eval(cls, base, divisor):
if base == 0:
return sympy.Integer(0)
def check_supported_type(x):
if (x.is_integer is False and x.is_real is False and x.is_complex) or x.is_Boolean:
raise TypeError(
f"unsupported operand type(s) for //: "
f"'{type(base).__name__}' and '{type(divisor).__name__}'"
f", expected integer or real")
check_supported_type(base)
check_supported_type(divisor)
# We don't provide the same error message as in Python because SymPy
# makes it difficult to check the types.
if divisor.is_zero:
raise ZeroDivisionError("division by zero")
# We don't cast the return type as in Python because SymPy makes it
# difficult to check the types.
if base.is_zero:
return sympy.S.Zero
if divisor == 1:
return base
return sympy.floor(base)
if isinstance(base, sympy.Integer) and isinstance(divisor, sympy.Integer):
return base // divisor
if isinstance(base, (sympy.Integer, sympy.Float)) and isinstance(divisor, (sympy.Integer, sympy.Float)):
return sympy.floor(base / divisor)
if isinstance(base, FloorDiv):
return FloorDiv(base.args[0], base.args[1] * divisor)
@ -243,6 +297,7 @@ if HAS_SYMPY:
sympy.simplify(base / gcd), sympy.simplify(divisor / gcd)
)
# Methods that have a `__foo__` as well as `__rfoo__`
reflectable_magic_methods = {
'add': lambda a, b: a + b,