mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
[fx] fix type promotion in binary_magic_impl (#91376)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91376 Approved by: https://github.com/ezyang, https://github.com/albanD
This commit is contained in:
parent
d13207c7ad
commit
88b3810c94
|
|
@ -10,19 +10,17 @@ import unittest
|
|||
import torch
|
||||
import operator
|
||||
import itertools
|
||||
import random
|
||||
import contextlib
|
||||
import math
|
||||
import builtins
|
||||
import atexit
|
||||
import io
|
||||
import os
|
||||
from torch.utils._pytree import tree_map
|
||||
from torch.fx.experimental import symbolic_shapes
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
from torch.fx.experimental.symbolic_shapes import FloorDiv, ShapeEnv, sym_float, guard_int, SymNode, sym_sqrt, sym_int, to_node
|
||||
from torch.fx.experimental.symbolic_shapes import FloorDiv, ShapeEnv, \
|
||||
guard_int, guard_float, SymNode, sym_sqrt, sym_int, sym_float, to_node
|
||||
from torch.utils._python_dispatch import TorchDispatchMode
|
||||
from torch import SymInt
|
||||
from torch import SymInt, SymFloat
|
||||
|
||||
aten = torch.ops.aten
|
||||
|
||||
|
|
@ -443,34 +441,6 @@ class f(torch.nn.Module):
|
|||
getitem_1: b8[s0 + s2, 2*s1] = native_dropout[1]; native_dropout = None
|
||||
return (getitem, getitem_1)""") # noqa: B950
|
||||
|
||||
# This environment variable controls whether or not we print expected failure
|
||||
# lists at the end of a test suite run. The intended usage looks like this:
|
||||
#
|
||||
# 1. Run `PYTORCH_COLLECT_EXPECT=1 python test/test_dynamic_shapes.py -k TestSymNumberMagicMethods`.
|
||||
# 2. Given the printed xfail list, add them to the set expected_failure_sym_magic_methods.
|
||||
COLLECT_EXPECT = os.getenv('PYTORCH_COLLECT_EXPECT', '0') == '1'
|
||||
|
||||
seen_failed = []
|
||||
def print_seen():
|
||||
out = []
|
||||
for key, reason in seen_failed:
|
||||
# Make sure the generated line is lint clean
|
||||
msg = f" {key}, # {reason}"
|
||||
eol = msg.find("\n")
|
||||
if eol != -1:
|
||||
msg = msg[:eol]
|
||||
out.append(msg[:120])
|
||||
|
||||
print("expected_failure_sym_magic_methods = {")
|
||||
print("\n".join(out))
|
||||
print("}")
|
||||
|
||||
if COLLECT_EXPECT:
|
||||
atexit.register(print_seen)
|
||||
|
||||
expected_failure_sym_magic_methods = {
|
||||
}
|
||||
|
||||
@skipIfTorchDynamo("Creating ShapeEnv fails for confusing reasons (also we never expect dynamo to see code like this)")
|
||||
class TestSymNumberMagicMethods(TestCase):
|
||||
def _do_test(self, fn, inp1, inp2, shape_env, is_unary_fn):
|
||||
|
|
@ -484,23 +454,28 @@ class TestSymNumberMagicMethods(TestCase):
|
|||
return torch.SymFloat(to_node(seed_node, inp))
|
||||
|
||||
def maybe_xfail(inp1, inp2):
|
||||
key = (fn, type(inp1).__name__, type(inp2).__name__)
|
||||
if COLLECT_EXPECT:
|
||||
@contextlib.contextmanager
|
||||
def context():
|
||||
try:
|
||||
yield
|
||||
except (TypeError, AssertionError) as e:
|
||||
seen_failed.append((key, str(e)))
|
||||
return context()
|
||||
|
||||
if key in expected_failure_sym_magic_methods:
|
||||
return self.assertRaises((TypeError, AssertionError))
|
||||
if fn == "sym_sqrt" and inp1 < 0 and type(inp1) in (SymFloat, SymInt):
|
||||
# TypeError: Cannot convert complex to float
|
||||
return self.assertRaises((TypeError,))
|
||||
elif fn == "sym_sqrt" and inp1 < 0:
|
||||
# ValueError: math domain error
|
||||
return self.assertRaises((ValueError,))
|
||||
elif fn in ("truediv", "floordiv", "mod") and inp2 == 0:
|
||||
# ZeroDivisionError: division by zero
|
||||
return self.assertRaises((ZeroDivisionError,))
|
||||
elif fn == "pow" and inp1 == 0 and inp2 < 0:
|
||||
# ZeroDivisionError: 0.0 cannot be raised to a negative power
|
||||
return self.assertRaises((ZeroDivisionError,))
|
||||
elif fn == "pow" and inp1 < 0 and inp2 in (2.5, -2.5) and (
|
||||
type(inp1) in (SymFloat, SymInt) or
|
||||
type(inp2) in (SymFloat, SymInt)
|
||||
):
|
||||
# Complex result, which we do not support:
|
||||
# TypeError: Cannot convert complex to float
|
||||
return self.assertRaises((TypeError,))
|
||||
else:
|
||||
return contextlib.nullcontext()
|
||||
|
||||
# These functions might return plain int/float
|
||||
has_valid_downcast = fn in ["min", "max"]
|
||||
if fn in symbolic_shapes.magic_methods_on_builtins:
|
||||
lambda_apply = getattr(builtins, fn)
|
||||
elif fn in symbolic_shapes.magic_methods_on_math:
|
||||
|
|
@ -510,26 +485,20 @@ class TestSymNumberMagicMethods(TestCase):
|
|||
else:
|
||||
lambda_apply = getattr(operator, fn)
|
||||
|
||||
if fn in symbolic_shapes.always_float_magic_methods:
|
||||
tp = "float"
|
||||
elif fn in symbolic_shapes.always_int_magic_methods:
|
||||
tp = "int"
|
||||
elif is_unary_fn:
|
||||
tp = "float" if isinstance(inp1, float) else "int"
|
||||
else:
|
||||
tp = "float" if any(isinstance(i, float) for i in [inp1, inp2]) else "int"
|
||||
|
||||
def guard_fn(v):
|
||||
try:
|
||||
if fn in symbolic_shapes.always_bool_magic_methods:
|
||||
return bool(v)
|
||||
else:
|
||||
return getattr(v.node, f"guard_{tp}")("", 0)
|
||||
if type(v) in (SymFloat, float):
|
||||
return guard_float(v)
|
||||
else: # SymInt, int
|
||||
res = guard_int(v)
|
||||
# We make sure that bools are represented as SymInts first
|
||||
# by calling guard_int, but then cast for compatibility with
|
||||
# a reference impl since we don't have SymBool.
|
||||
if fn in symbolic_shapes.always_bool_magic_methods:
|
||||
return bool(res)
|
||||
return res
|
||||
except Exception as e:
|
||||
if has_valid_downcast:
|
||||
return v
|
||||
else:
|
||||
raise e
|
||||
raise e
|
||||
|
||||
# Get reference result
|
||||
with maybe_xfail(inp1, inp2):
|
||||
|
|
@ -545,7 +514,8 @@ class TestSymNumberMagicMethods(TestCase):
|
|||
out = lambda_apply(sym_inp1)
|
||||
else:
|
||||
out = lambda_apply(sym_inp1, inp2)
|
||||
self.assertEqual(guard_fn(out), ref_out)
|
||||
out = guard_fn(out)
|
||||
self.assertEqual(out, ref_out)
|
||||
|
||||
if is_unary_fn:
|
||||
return
|
||||
|
|
@ -554,12 +524,14 @@ class TestSymNumberMagicMethods(TestCase):
|
|||
sym_inp2 = get_sym_inp(inp2)
|
||||
with maybe_xfail(inp1, sym_inp2):
|
||||
out = lambda_apply(inp1, sym_inp2)
|
||||
self.assertEqual(guard_fn(out), ref_out)
|
||||
out = guard_fn(out)
|
||||
self.assertEqual(out, ref_out)
|
||||
|
||||
# Symified both args
|
||||
with maybe_xfail(sym_inp1, sym_inp2):
|
||||
out = lambda_apply(sym_inp1, sym_inp2)
|
||||
self.assertEqual(guard_fn(out), ref_out)
|
||||
out = guard_fn(out)
|
||||
self.assertEqual(out, ref_out)
|
||||
|
||||
|
||||
@parametrize("fn", list(symbolic_shapes.magic_methods.keys()))
|
||||
|
|
@ -574,18 +546,30 @@ class TestSymNumberMagicMethods(TestCase):
|
|||
if is_unary_fn and second_type == "float":
|
||||
self.skipTest(f"{fn} is unary and already tested")
|
||||
|
||||
# We could pass int/float directly for types but then the
|
||||
# mangled test name is bad
|
||||
inp1 = random.random() * 2.5
|
||||
if first_type == "int":
|
||||
inp1 = int(inp1)
|
||||
inp2 = random.random() * 2.5
|
||||
if second_type == "int":
|
||||
inp2 = int(inp2)
|
||||
# Only floats here since these will be converted to int if necessary.
|
||||
# We also ignore complex and bool.
|
||||
values = (
|
||||
0.0,
|
||||
1.0,
|
||||
2.5,
|
||||
)
|
||||
|
||||
shape_env = ShapeEnv()
|
||||
neg_values = tuple(-x for x in values)
|
||||
|
||||
self._do_test(fn, inp1, inp2, shape_env, is_unary_fn)
|
||||
for inp1, inp2 in itertools.chain(
|
||||
itertools.product(values, values),
|
||||
itertools.product(values, neg_values),
|
||||
itertools.product(neg_values, values),
|
||||
itertools.product(neg_values, neg_values),
|
||||
):
|
||||
if first_type == "int":
|
||||
inp1 = int(inp1)
|
||||
if second_type == "int":
|
||||
inp2 = int(inp2)
|
||||
|
||||
shape_env = ShapeEnv()
|
||||
|
||||
self._do_test(fn, inp1, inp2, shape_env, is_unary_fn)
|
||||
|
||||
instantiate_parametrized_tests(TestSymNumberMagicMethods)
|
||||
|
||||
|
|
|
|||
|
|
@ -33,7 +33,7 @@ aten = torch._ops.ops.aten # type: ignore[has-type]
|
|||
|
||||
__all__ = [
|
||||
"has_symbolic_sizes_strides", "create_contiguous", "ShapeEnv",
|
||||
"SymDispatchMode", "FloorDiv", "guard_int", "wrap_node",
|
||||
"SymDispatchMode", "FloorDiv", "guard_int", "guard_float", "wrap_node",
|
||||
]
|
||||
|
||||
SYM_FUNCTION_MODE = None
|
||||
|
|
@ -105,6 +105,12 @@ def guard_int(a):
|
|||
assert isinstance(a, int)
|
||||
return a
|
||||
|
||||
def guard_float(a):
|
||||
if isinstance(a, SymFloat):
|
||||
return a.node.guard_float("", 0) # NB: uses Python backtrace
|
||||
assert isinstance(a, float)
|
||||
return a
|
||||
|
||||
# Drop in replacement for math.sqrt
|
||||
def sym_sqrt(a):
|
||||
if hasattr(a, '__sym_sqrt__'):
|
||||
|
|
@ -199,7 +205,15 @@ class SymNode:
|
|||
def guard_int(self, file, line):
|
||||
# TODO: use the file/line for some useful diagnostic on why a
|
||||
# guard occurred
|
||||
return int(self.shape_env.evaluate_expr(self.expr))
|
||||
# Because there is no SymBool, we wrap bools into SymInt during
|
||||
# construction. So we have to handle bools here.
|
||||
res = self.shape_env.evaluate_expr(self.expr)
|
||||
if res is sympy.sympify(False):
|
||||
return 0
|
||||
elif res is sympy.sympify(True):
|
||||
return 1
|
||||
else:
|
||||
return int(res)
|
||||
|
||||
def guard_float(self, file, line):
|
||||
# TODO: use the file/line for some useful diagnostic on why a
|
||||
|
|
@ -211,6 +225,28 @@ class SymNode:
|
|||
|
||||
|
||||
if HAS_SYMPY:
|
||||
# Overloaded to be compatible with regular Python.
|
||||
# https://github.com/pytorch/pytorch/issues/90900
|
||||
class Pow(sympy.Function):
|
||||
@classmethod
|
||||
def eval(cls, base, exp):
|
||||
if exp == 0:
|
||||
return sympy.Integer(1)
|
||||
elif base == 0 and exp < 0:
|
||||
raise ZeroDivisionError(f"{base} cannot be raised to a negative power")
|
||||
else:
|
||||
return base ** exp
|
||||
|
||||
# Overloaded to be compatible with regular Python.
|
||||
# https://github.com/pytorch/pytorch/issues/90900
|
||||
class TrueDiv(sympy.Function):
|
||||
@classmethod
|
||||
def eval(cls, base, divisor):
|
||||
if divisor == 0:
|
||||
raise ZeroDivisionError("division by zero")
|
||||
else:
|
||||
return base / divisor
|
||||
|
||||
# NOTE [ SymPy eval and assumptions ]
|
||||
# In eval, we only return values in cases where we always want to evaluate.
|
||||
# In other cases, the result will just be FloorDiv(a, b), which needs to be
|
||||
|
|
@ -304,8 +340,8 @@ reflectable_magic_methods = {
|
|||
'sub': lambda a, b: a - b,
|
||||
'mul': lambda a, b: a * b,
|
||||
'mod': lambda a, b: a % b,
|
||||
'pow': lambda a, b: a ** b,
|
||||
'truediv': lambda a, b: a / b,
|
||||
'pow': lambda a, b: Pow(a, b),
|
||||
'truediv': lambda a, b: TrueDiv(a, b),
|
||||
'floordiv': lambda a, b: FloorDiv(a, b),
|
||||
}
|
||||
|
||||
|
|
@ -337,7 +373,7 @@ magic_methods_on_builtins = {"min", "max"}
|
|||
magic_methods_on_math = {"ceil", "floor"}
|
||||
magic_methods_on_submodule = {"sym_float", "sym_sqrt"}
|
||||
|
||||
always_float_magic_methods = {"truediv", "sym_float", "sym_sqrt"}
|
||||
always_float_magic_methods = {"truediv", "sym_float", "sym_sqrt", "pow"}
|
||||
always_int_magic_methods = {"ceil", "floor"}
|
||||
always_bool_magic_methods = {"eq", "gt", "lt", "le", "ge"}
|
||||
|
||||
|
|
@ -376,13 +412,30 @@ def _make_node_magic(method, func):
|
|||
raise
|
||||
out = sympy.expand(out)
|
||||
pytype: Type
|
||||
# This is not strictly correct. In Python, a**b may return complex when
|
||||
# a < 0 and b is a float: (-1)**2.1. Same for sympy.sqrt(-3.14). This
|
||||
# returns a float while both arguments are ints: 2**(-1). Also, max and
|
||||
# min do not type promote. To avoid having data-dependent control flow
|
||||
# here, we just set the type to float if one of the args is a float. In
|
||||
# case of a type mismatch, we assume that it will be detected during
|
||||
# evaluation.
|
||||
if method in always_float_magic_methods:
|
||||
pytype = float
|
||||
elif method in ("min", "max") and self.pytype is int and other.pytype is int:
|
||||
# These ops don't type promote. The result type depends on arg
|
||||
# values. But when both args are ints, we can be sure the result is
|
||||
# an int as well. Otherwise, we assume the result is a float and let
|
||||
# one of the cases below handle that. That's not strictly correct,
|
||||
# but it's the best we can do without being data-dependent.
|
||||
pytype = int
|
||||
elif method in always_bool_magic_methods:
|
||||
# This should return bool, but we have no SymBool, see wrap_node.
|
||||
pytype = int
|
||||
elif self.pytype is float or other.pytype is float:
|
||||
pytype = float
|
||||
else:
|
||||
pytype = self.pytype
|
||||
|
||||
# TODO: relational operators actually technically return a
|
||||
# PySymBool, this is a type error
|
||||
return SymNode(out, self.shape_env, pytype)
|
||||
|
||||
def unary_magic_impl(self):
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user