[fx] fix type promotion in binary_magic_impl (#91376)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/91376
Approved by: https://github.com/ezyang, https://github.com/albanD
This commit is contained in:
Nikita Karetnikov 2023-01-16 10:52:36 +01:00 committed by PyTorch MergeBot
parent d13207c7ad
commit 88b3810c94
2 changed files with 121 additions and 84 deletions

View File

@ -10,19 +10,17 @@ import unittest
import torch
import operator
import itertools
import random
import contextlib
import math
import builtins
import atexit
import io
import os
from torch.utils._pytree import tree_map
from torch.fx.experimental import symbolic_shapes
from torch.fx.experimental.proxy_tensor import make_fx
from torch.fx.experimental.symbolic_shapes import FloorDiv, ShapeEnv, sym_float, guard_int, SymNode, sym_sqrt, sym_int, to_node
from torch.fx.experimental.symbolic_shapes import FloorDiv, ShapeEnv, \
guard_int, guard_float, SymNode, sym_sqrt, sym_int, sym_float, to_node
from torch.utils._python_dispatch import TorchDispatchMode
from torch import SymInt
from torch import SymInt, SymFloat
aten = torch.ops.aten
@ -443,34 +441,6 @@ class f(torch.nn.Module):
getitem_1: b8[s0 + s2, 2*s1] = native_dropout[1]; native_dropout = None
return (getitem, getitem_1)""") # noqa: B950
# This environment variable controls whether or not we print expected failure
# lists at the end of a test suite run. The intended usage looks like this:
#
# 1. Run `PYTORCH_COLLECT_EXPECT=1 python test/test_dynamic_shapes.py -k TestSymNumberMagicMethods`.
# 2. Given the printed xfail list, add them to the set expected_failure_sym_magic_methods.
COLLECT_EXPECT = os.getenv('PYTORCH_COLLECT_EXPECT', '0') == '1'
seen_failed = []
def print_seen():
out = []
for key, reason in seen_failed:
# Make sure the generated line is lint clean
msg = f" {key}, # {reason}"
eol = msg.find("\n")
if eol != -1:
msg = msg[:eol]
out.append(msg[:120])
print("expected_failure_sym_magic_methods = {")
print("\n".join(out))
print("}")
if COLLECT_EXPECT:
atexit.register(print_seen)
expected_failure_sym_magic_methods = {
}
@skipIfTorchDynamo("Creating ShapeEnv fails for confusing reasons (also we never expect dynamo to see code like this)")
class TestSymNumberMagicMethods(TestCase):
def _do_test(self, fn, inp1, inp2, shape_env, is_unary_fn):
@ -484,23 +454,28 @@ class TestSymNumberMagicMethods(TestCase):
return torch.SymFloat(to_node(seed_node, inp))
def maybe_xfail(inp1, inp2):
key = (fn, type(inp1).__name__, type(inp2).__name__)
if COLLECT_EXPECT:
@contextlib.contextmanager
def context():
try:
yield
except (TypeError, AssertionError) as e:
seen_failed.append((key, str(e)))
return context()
if key in expected_failure_sym_magic_methods:
return self.assertRaises((TypeError, AssertionError))
if fn == "sym_sqrt" and inp1 < 0 and type(inp1) in (SymFloat, SymInt):
# TypeError: Cannot convert complex to float
return self.assertRaises((TypeError,))
elif fn == "sym_sqrt" and inp1 < 0:
# ValueError: math domain error
return self.assertRaises((ValueError,))
elif fn in ("truediv", "floordiv", "mod") and inp2 == 0:
# ZeroDivisionError: division by zero
return self.assertRaises((ZeroDivisionError,))
elif fn == "pow" and inp1 == 0 and inp2 < 0:
# ZeroDivisionError: 0.0 cannot be raised to a negative power
return self.assertRaises((ZeroDivisionError,))
elif fn == "pow" and inp1 < 0 and inp2 in (2.5, -2.5) and (
type(inp1) in (SymFloat, SymInt) or
type(inp2) in (SymFloat, SymInt)
):
# Complex result, which we do not support:
# TypeError: Cannot convert complex to float
return self.assertRaises((TypeError,))
else:
return contextlib.nullcontext()
# These functions might return plain int/float
has_valid_downcast = fn in ["min", "max"]
if fn in symbolic_shapes.magic_methods_on_builtins:
lambda_apply = getattr(builtins, fn)
elif fn in symbolic_shapes.magic_methods_on_math:
@ -510,26 +485,20 @@ class TestSymNumberMagicMethods(TestCase):
else:
lambda_apply = getattr(operator, fn)
if fn in symbolic_shapes.always_float_magic_methods:
tp = "float"
elif fn in symbolic_shapes.always_int_magic_methods:
tp = "int"
elif is_unary_fn:
tp = "float" if isinstance(inp1, float) else "int"
else:
tp = "float" if any(isinstance(i, float) for i in [inp1, inp2]) else "int"
def guard_fn(v):
try:
if fn in symbolic_shapes.always_bool_magic_methods:
return bool(v)
else:
return getattr(v.node, f"guard_{tp}")("", 0)
if type(v) in (SymFloat, float):
return guard_float(v)
else: # SymInt, int
res = guard_int(v)
# We make sure that bools are represented as SymInts first
# by calling guard_int, but then cast for compatibility with
# a reference impl since we don't have SymBool.
if fn in symbolic_shapes.always_bool_magic_methods:
return bool(res)
return res
except Exception as e:
if has_valid_downcast:
return v
else:
raise e
raise e
# Get reference result
with maybe_xfail(inp1, inp2):
@ -545,7 +514,8 @@ class TestSymNumberMagicMethods(TestCase):
out = lambda_apply(sym_inp1)
else:
out = lambda_apply(sym_inp1, inp2)
self.assertEqual(guard_fn(out), ref_out)
out = guard_fn(out)
self.assertEqual(out, ref_out)
if is_unary_fn:
return
@ -554,12 +524,14 @@ class TestSymNumberMagicMethods(TestCase):
sym_inp2 = get_sym_inp(inp2)
with maybe_xfail(inp1, sym_inp2):
out = lambda_apply(inp1, sym_inp2)
self.assertEqual(guard_fn(out), ref_out)
out = guard_fn(out)
self.assertEqual(out, ref_out)
# Symified both args
with maybe_xfail(sym_inp1, sym_inp2):
out = lambda_apply(sym_inp1, sym_inp2)
self.assertEqual(guard_fn(out), ref_out)
out = guard_fn(out)
self.assertEqual(out, ref_out)
@parametrize("fn", list(symbolic_shapes.magic_methods.keys()))
@ -574,18 +546,30 @@ class TestSymNumberMagicMethods(TestCase):
if is_unary_fn and second_type == "float":
self.skipTest(f"{fn} is unary and already tested")
# We could pass int/float directly for types but then the
# mangled test name is bad
inp1 = random.random() * 2.5
if first_type == "int":
inp1 = int(inp1)
inp2 = random.random() * 2.5
if second_type == "int":
inp2 = int(inp2)
# Only floats here since these will be converted to int if necessary.
# We also ignore complex and bool.
values = (
0.0,
1.0,
2.5,
)
shape_env = ShapeEnv()
neg_values = tuple(-x for x in values)
self._do_test(fn, inp1, inp2, shape_env, is_unary_fn)
for inp1, inp2 in itertools.chain(
itertools.product(values, values),
itertools.product(values, neg_values),
itertools.product(neg_values, values),
itertools.product(neg_values, neg_values),
):
if first_type == "int":
inp1 = int(inp1)
if second_type == "int":
inp2 = int(inp2)
shape_env = ShapeEnv()
self._do_test(fn, inp1, inp2, shape_env, is_unary_fn)
instantiate_parametrized_tests(TestSymNumberMagicMethods)

View File

@ -33,7 +33,7 @@ aten = torch._ops.ops.aten # type: ignore[has-type]
__all__ = [
"has_symbolic_sizes_strides", "create_contiguous", "ShapeEnv",
"SymDispatchMode", "FloorDiv", "guard_int", "wrap_node",
"SymDispatchMode", "FloorDiv", "guard_int", "guard_float", "wrap_node",
]
SYM_FUNCTION_MODE = None
@ -105,6 +105,12 @@ def guard_int(a):
assert isinstance(a, int)
return a
def guard_float(a):
if isinstance(a, SymFloat):
return a.node.guard_float("", 0) # NB: uses Python backtrace
assert isinstance(a, float)
return a
# Drop in replacement for math.sqrt
def sym_sqrt(a):
if hasattr(a, '__sym_sqrt__'):
@ -199,7 +205,15 @@ class SymNode:
def guard_int(self, file, line):
# TODO: use the file/line for some useful diagnostic on why a
# guard occurred
return int(self.shape_env.evaluate_expr(self.expr))
# Because there is no SymBool, we wrap bools into SymInt during
# construction. So we have to handle bools here.
res = self.shape_env.evaluate_expr(self.expr)
if res is sympy.sympify(False):
return 0
elif res is sympy.sympify(True):
return 1
else:
return int(res)
def guard_float(self, file, line):
# TODO: use the file/line for some useful diagnostic on why a
@ -211,6 +225,28 @@ class SymNode:
if HAS_SYMPY:
# Overloaded to be compatible with regular Python.
# https://github.com/pytorch/pytorch/issues/90900
class Pow(sympy.Function):
@classmethod
def eval(cls, base, exp):
if exp == 0:
return sympy.Integer(1)
elif base == 0 and exp < 0:
raise ZeroDivisionError(f"{base} cannot be raised to a negative power")
else:
return base ** exp
# Overloaded to be compatible with regular Python.
# https://github.com/pytorch/pytorch/issues/90900
class TrueDiv(sympy.Function):
@classmethod
def eval(cls, base, divisor):
if divisor == 0:
raise ZeroDivisionError("division by zero")
else:
return base / divisor
# NOTE [ SymPy eval and assumptions ]
# In eval, we only return values in cases where we always want to evaluate.
# In other cases, the result will just be FloorDiv(a, b), which needs to be
@ -304,8 +340,8 @@ reflectable_magic_methods = {
'sub': lambda a, b: a - b,
'mul': lambda a, b: a * b,
'mod': lambda a, b: a % b,
'pow': lambda a, b: a ** b,
'truediv': lambda a, b: a / b,
'pow': lambda a, b: Pow(a, b),
'truediv': lambda a, b: TrueDiv(a, b),
'floordiv': lambda a, b: FloorDiv(a, b),
}
@ -337,7 +373,7 @@ magic_methods_on_builtins = {"min", "max"}
magic_methods_on_math = {"ceil", "floor"}
magic_methods_on_submodule = {"sym_float", "sym_sqrt"}
always_float_magic_methods = {"truediv", "sym_float", "sym_sqrt"}
always_float_magic_methods = {"truediv", "sym_float", "sym_sqrt", "pow"}
always_int_magic_methods = {"ceil", "floor"}
always_bool_magic_methods = {"eq", "gt", "lt", "le", "ge"}
@ -376,13 +412,30 @@ def _make_node_magic(method, func):
raise
out = sympy.expand(out)
pytype: Type
# This is not strictly correct. In Python, a**b may return complex when
# a < 0 and b is a float: (-1)**2.1. Same for sympy.sqrt(-3.14). This
# returns a float while both arguments are ints: 2**(-1). Also, max and
# min do not type promote. To avoid having data-dependent control flow
# here, we just set the type to float if one of the args is a float. In
# case of a type mismatch, we assume that it will be detected during
# evaluation.
if method in always_float_magic_methods:
pytype = float
elif method in ("min", "max") and self.pytype is int and other.pytype is int:
# These ops don't type promote. The result type depends on arg
# values. But when both args are ints, we can be sure the result is
# an int as well. Otherwise, we assume the result is a float and let
# one of the cases below handle that. That's not strictly correct,
# but it's the best we can do without being data-dependent.
pytype = int
elif method in always_bool_magic_methods:
# This should return bool, but we have no SymBool, see wrap_node.
pytype = int
elif self.pytype is float or other.pytype is float:
pytype = float
else:
pytype = self.pytype
# TODO: relational operators actually technically return a
# PySymBool, this is a type error
return SymNode(out, self.shape_env, pytype)
def unary_magic_impl(self):