Track and record hint on SymNode and use when possible (#94201)

Historically, we computed `size_hint` on the fly by substituting the `var_to_val` mapping into the sympy expression. With this change, we also maintain the hint directly on the SymNode (as `SymNode._hint`) and use it in lieu of sympy substitution whenever it is available. The hint is mostly available for guards on SymInt and friends; in idiomatic Inductor code, by contrast, we typically manipulate sympy expressions directly and so have no convenient way to maintain hints.
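
Concretely, the hint is carried along when SymNodes are combined: a binary operation computes its output hint eagerly from the input hints with the matching Python operator, and an input without a hint makes the output hintless as well. A standalone sketch of that rule (illustrative only; the real logic lives in `_make_node_magic` in the diff below):

```python
import operator

def combine_hints(op, lhs_hint, rhs_hint):
    # Only produce an output hint when both inputs have one; an unbacked
    # input (hint None) makes the result unbacked as well.
    if lhs_hint is None or rhs_hint is None:
        return None
    return op(lhs_hint, rhs_hint)

assert combine_hints(operator.add, 3, 4) == 7
assert combine_hints(operator.mul, 3, None) is None
```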

While it's possible this will give us modest performance improvements, this is not the point of this PR; the goal is to make it easier to carefully handle unbacked SymInts, where hints are expected not to be available. You can now easily test if a SymInt is backed or not by checking `symint.node.hint is None`.
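
For example, a minimal sketch against the APIs added in this PR (import paths and the `create_symbol`/`create_symintnode` signatures are as of this commit and may have changed since):

```python
from torch._dynamo.source import ConstantSource
from torch.fx.experimental.symbolic_shapes import ShapeEnv

shape_env = ShapeEnv()

# Backed: created from a concrete input value, so a hint is recorded.
backed = shape_env.create_symintnode(
    shape_env.create_symbol(5, source=ConstantSource("x")), hint=5
)
assert backed.node.hint == 5

# Unbacked: the result of some tensor-dependent computation; no hint is available.
unbacked = shape_env.create_unbacked_symint()
assert unbacked.node.hint is None
```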

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94201
Approved by: https://github.com/voznesenskym
Author: Edward Z. Yang <ezyang@meta.com>, 2023-02-08 14:03:49 -05:00, committed by PyTorch MergeBot
parent b5ef37b9a4
commit dc70b00d0b
8 changed files with 131 additions and 50 deletions


@@ -125,10 +125,11 @@ def create_symbolic_tensor(name, arg, shape_env):
shape_env.create_symbolic_sizes_strides_storage_offset(arg, source=ConstantSource(name))
return FakeSymbolicTensor(sym_shapes, sym_strides, arg.dtype, arg.layout, arg.requires_grad, arg.device, sym_storage_offset)
def create_symint(shape_env, i):
def create_symint(shape_env, i: int):
from torch._dynamo.source import ConstantSource
return shape_env.create_symintnode(
shape_env.create_symbol(i, source=ConstantSource(f"__testing_only{len(shape_env.var_to_val)}"))
shape_env.create_symbol(i, source=ConstantSource(f"__testing_only{len(shape_env.var_to_val)}")),
hint=i
)
@skipIfTorchDynamo("Creating ShapeEnv fails for confusing reasons (also we never expect dynamo to see code like this)")
@@ -478,10 +479,7 @@ class TestSymNumberMagicMethods(TestCase):
return torch.SymFloat(to_node(seed_node, inp))
def maybe_xfail(inp1, inp2):
if fn == "sym_sqrt" and inp1 < 0 and type(inp1) in (SymFloat, SymInt):
# TypeError: Cannot convert complex to float
return self.assertRaises((TypeError,))
elif fn == "sym_sqrt" and inp1 < 0:
if fn == "sym_sqrt" and inp1 < 0:
# ValueError: math domain error
return self.assertRaises((ValueError,))
elif fn in ("truediv", "floordiv", "mod") and inp2 == 0:


@@ -413,6 +413,9 @@ def sym_max(a, b):
if isinstance(a, (SymInt, SymFloat)):
return a.__sym_max__(b)
elif isinstance(b, (SymInt, SymFloat)):
# NB: If you actually care about preserving output type exactly
# if you do something like max(0, 0.0), it is NOT sound to treat
# min/max as commutative
return b.__sym_max__(a)
return builtins.max(a, b) # type: ignore[operator]
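
To see why the NB above matters, plain Python `max` already shows the asymmetry: for equal inputs it returns its first argument, so argument order determines the output type. Illustrative note, not part of the commit:

```python
# For equal inputs, builtins.max returns its *first* argument, so the result
# type depends on argument order; this is exactly the max(0, 0.0) case above.
assert type(max(0, 0.0)) is int     # returns 0
assert type(max(0.0, 0)) is float   # returns 0.0
```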


@@ -687,7 +687,7 @@ class VariableBuilder:
):
shape_env = self.tx.output.shape_env
wrapped_value = shape_env.create_symintnode(
shape_env.create_symbol(value, source=self.source)
shape_env.create_symbol(value, source=self.source), hint=value
)
self.tx.output.tracked_fakes.append(
TrackedFake(wrapped_value, self.source)


@@ -6,6 +6,7 @@ from typing import Dict, List
import torch.fx
import torch.random
from torch.fx.experimental.symbolic_shapes import guard_scalar
from .. import config, variables
from ..exc import unimplemented
@@ -460,9 +461,7 @@ class SymNodeVariable(VariableTracker):
return self.proxy
def evaluate_expr(self, output_graph):
if not isinstance(self.sym_num, torch.SymInt):
return self.sym_num
return output_graph.shape_env.evaluate_expr(self.sym_num.node.expr)
return guard_scalar(self.sym_num)
def call_method(
self,


@@ -1,4 +1,5 @@
from torch.fx.experimental.proxy_tensor import is_sym_node, py_sym_types
from torch.fx.experimental.symbolic_shapes import hint_int
import torch
import torch.fx as fx
import operator
@@ -221,21 +222,14 @@ def _tensor_nbytes(numel, dtype):
return numel * sizes[dtype]
def _size_of(node: fx.Node) -> int:
def to_size_hint(s):
if isinstance(s, torch.SymInt):
py_s = s.node
return py_s.shape_env.size_hint(py_s.expr)
assert isinstance(s, int)
return s
if 'val' in node.meta:
val = node.meta['val']
if isinstance(val, py_sym_types):
return 1
elif isinstance(val, (list, tuple)):
return sum(_tensor_nbytes(to_size_hint(n.numel()), n.dtype) for n in val if isinstance(n, torch.Tensor))
return sum(_tensor_nbytes(hint_int(n.numel()), n.dtype) for n in val if isinstance(n, torch.Tensor))
elif isinstance(val, torch.Tensor):
return _tensor_nbytes(to_size_hint(val.numel()), val.dtype)
return _tensor_nbytes(hint_int(val.numel()), val.dtype)
raise RuntimeError(f"Unknown metadata type {type(val)}")


@@ -2482,7 +2482,7 @@ class ExternKernel(InputsKernel):
tensor_args.append(arg)
else:
if isinstance(arg, sympy.Expr):
arg = V.graph.sizevars.shape_env.create_symintnode(arg)
arg = V.graph.sizevars.shape_env.create_symintnode(arg, hint=None)
non_tensor_args.append(arg)
def unflatten_args(new_tensor_args, new_non_tensor_args):


@@ -107,7 +107,7 @@ def convert_shape_to_symint(
if isinstance(i, int)
else int(i)
if isinstance(i, sympy.Integer)
else V.graph.sizevars.shape_env.create_symintnode(i)
else V.graph.sizevars.shape_env.create_symintnode(i, hint=None)
for i in lst
]


@@ -37,8 +37,8 @@ aten = torch._ops.ops.aten # type: ignore[has-type]
__all__ = [
"has_symbolic_sizes_strides", "create_contiguous", "ShapeEnv",
"SymDispatchMode", "FloorDiv", "guard_int", "guard_float", "wrap_node",
"method_to_operator", "SYMPY_INTERP",
"SymDispatchMode", "FloorDiv", "guard_int", "guard_float", "guard_scalar", "wrap_node",
"method_to_operator", "hint_int", "SYMPY_INTERP",
]
SYM_FUNCTION_MODE = None
@@ -104,22 +104,38 @@ def _handle_sym_dispatch(func, args, kwargs):
finally:
SYM_FUNCTION_MODE = mode
def hint_int(a):
if isinstance(a, torch.SymInt):
return a.node.require_hint()
assert type(a) is int, a
return a
def guard_scalar(a):
if isinstance(a, (SymBool, bool)):
return guard_bool(a)
elif isinstance(a, (SymInt, int)):
return guard_int(a)
elif isinstance(a, (SymFloat, float)):
return guard_float(a)
else:
raise AssertionError(f"unrecognized scalar {a}")
def guard_bool(a):
if isinstance(a, SymBool):
return a.node.guard_bool("", 0) # NB: uses Python backtrace
assert type(a) is bool
assert type(a) is bool, a
return a
def guard_int(a):
if isinstance(a, SymInt):
return a.node.guard_int("", 0) # NB: uses Python backtrace
assert type(a) is int
assert type(a) is int, a
return a
def guard_float(a):
if isinstance(a, SymFloat):
return a.node.guard_float("", 0) # NB: uses Python backtrace
assert isinstance(a, float)
assert isinstance(a, float), a
return a
# Drop in replacement for math.sqrt
@@ -163,17 +179,67 @@ class SymNode:
This is a type erased SymInt/SymFloat which we use to do actual operations.
End users don't touch this. Magic methods are NOT defined on this object.
"""
def __init__(self, expr, shape_env, pytype, constant=None):
def __init__(self, expr, shape_env, pytype, hint: Optional[Union[int, float]], constant=None):
self._expr = expr
self.shape_env = shape_env
self.pytype = pytype
self.constant = constant
# What's the difference between hint and constant?
#
# - A constant is known to be invariant across invocations of the model;
# it will always be this value. We only really know this when we
# encounter an honest-to-goodness literal (when wrapping it into
# a SymNode, we set constant.) Most of the time, constant is None
#
# - A hint is a *particular* value from the particular run we are
# tracing, but it may vary the next time around. It's useful to
# keep this around, as if we need a concrete value from a SymNode,
# we will return the hint and guard on the expression that produced
# it giving the same hint next time around. The hint is not
# guaranteed to be set either: if you have an unbacked SymNode,
# there won't be any hint; it was the result of some tensor-dependent
# computation, but we don't know what it actually is because we
# haven't actually run the tensor computation.
#
# hint_expr is only set if we don't have a hint. When it is set, it
# contains the expression which contains the unbacked symnodes that,
# if constrained, would allow this expression to be hinted again.
if hint is None:
self._hint_expr = self.expr.xreplace(shape_env.var_to_val)
self._hint = None
self._update_hint() # check if the replacement actually was enough
else:
self._hint_expr = None
self._hint = hint
self.constant: Optional[Union[int, float, bool]] = constant
@property
def expr(self):
self._update_expr()
return self._expr
# Check if we have replacements hint_expr that would allow us to
# simplify it into a hint
def _update_hint(self):
if self._hint_expr.free_symbols <= self.shape_env.replacements.keys():
self._hint = self.pytype(self.shape_env.replace(self._hint_expr))
self._hint_expr = None
@property
def hint(self):
if self._hint is None:
self._update_hint()
return self._hint
def require_hint(self):
if self._hint is None:
self._update_hint()
if self._hint is None:
raise self.shape_env._make_data_dependent_error(self._hint_expr)
else:
return self._hint
else:
return self._hint
def _update_expr(self):
self._expr = self.shape_env.replace(self._expr)
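
An illustrative sketch of the constant/hint taxonomy described in the comment above, using the APIs as they appear in this diff (not part of the commit; the exact exception raised by `require_hint` is whatever `_make_data_dependent_error` constructs):

```python
from torch.fx.experimental.symbolic_shapes import ShapeEnv

env = ShapeEnv()

# Unbacked SymInt: neither constant nor hint is known.
u = env.create_unbacked_symint()
assert u.node.constant is None and u.node.hint is None

# Honest-to-goodness literal: invariant across runs, so constant and hint agree.
lit = u.node.wrap_int(4)
assert lit.constant == 4 and lit.hint == 4

# require_hint() refuses to guess for the unbacked node: it raises the
# ShapeEnv's data-dependent error rather than returning a value.
try:
    u.node.require_hint()
except Exception as exc:  # exact exception class is not shown in this hunk
    print(type(exc).__name__, exc)
```
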
@@ -188,15 +254,15 @@ class SymNode:
def wrap_int(self, num):
assert type(num) is int
return SymNode(sympy.Integer(num), self.shape_env, int, constant=num)
return SymNode(sympy.Integer(num), self.shape_env, int, num, constant=num)
def wrap_float(self, num):
assert type(num) is float
return SymNode(sympy.Float(num), self.shape_env, float, constant=num)
return SymNode(sympy.Float(num), self.shape_env, float, num, constant=num)
def wrap_bool(self, num):
assert type(num) is bool
return SymNode(sympy.true if num else sympy.false, self.shape_env, bool, constant=num)
return SymNode(sympy.true if num else sympy.false, self.shape_env, bool, num, constant=num)
def clone(self):
return self
@@ -240,7 +306,7 @@ class SymNode:
def guard_int(self, file, line):
# TODO: use the file/line for some useful diagnostic on why a
# guard occurred
r = self.shape_env.evaluate_expr(self.expr)
r = self.shape_env.evaluate_expr(self.expr, self.hint)
try:
return int(r)
except Exception:
@@ -250,7 +316,7 @@ class SymNode:
def guard_float(self, file, line):
# TODO: use the file/line for some useful diagnostic on why a
# guard occurred
r = self.shape_env.evaluate_expr(self.expr)
r = self.shape_env.evaluate_expr(self.expr, self.hint)
try:
return float(r)
except Exception:
@@ -261,7 +327,7 @@ class SymNode:
# TODO: use the file/line for some useful diagnostic on why a
# guard occurred
# TODO: why is the replace needed here?
r = self.shape_env.evaluate_expr(self.shape_env.replace(self.expr))
r = self.shape_env.evaluate_expr(self.shape_env.replace(self.expr), self.hint)
try:
return bool(r)
except Exception:
@@ -564,6 +630,9 @@ def _make_node_magic(method, func):
log.warning(f"failed to eval {method}({expr}, {other_expr})")
raise
out = safe_expand(out)
out_hint = None
if self.hint is not None and other.hint is not None:
out_hint = op(self.hint, other.hint)
pytype: Type
# This is not strictly correct. In Python, a**b may return complex when
# a < 0 and b is a float: (-1)**2.1. Same for sympy.sqrt(-3.14). This
@@ -581,11 +650,11 @@ def _make_node_magic(method, func):
else:
pytype = self.pytype
return SymNode(out, self.shape_env, pytype)
return SymNode(out, self.shape_env, pytype, out_hint)
def unary_magic_impl(self):
if SYM_FUNCTION_MODE:
op = method_to_operator(method)
if SYM_FUNCTION_MODE:
r = _handle_sym_dispatch(op, (wrap_node(self),), {})
assert isinstance(r, SymTypes), type(r)
return r.node
@@ -596,6 +665,9 @@ def _make_node_magic(method, func):
except Exception:
log.warning(f"failed to eval {method}({expr})")
raise
out_hint = None
if self.hint is not None:
out_hint = op(self.hint)
out = safe_expand(out)
pytype: Type
if method in always_int_magic_methods:
@@ -605,7 +677,7 @@ def _make_node_magic(method, func):
else:
pytype = self.pytype
return SymNode(out, self.shape_env, pytype)
return SymNode(out, self.shape_env, pytype, out_hint)
if method in unary_magic_methods:
setattr(SymNode, method_attr, unary_magic_impl)
@@ -628,8 +700,16 @@ def _make_node_sizes_strides(method, func):
except Exception:
log.warning(f"failed to eval {method}(*{size_exprs}, *{stride_exprs})")
raise
hints = []
out_hint = None
for s in itertools.chain(sizes, strides):
if s.hint is None:
break
hints.append(s.hint)
else:
out_hint = op(*hints)
# bool is never expandable
return SymNode(sympy.Eq(out, 1), self.shape_env, bool)
return SymNode(sympy.Eq(out, 1), self.shape_env, bool, out_hint)
setattr(SymNode, method, sizes_strides_impl)
@@ -824,31 +904,34 @@ class ShapeEnv:
TensorPropertySource(source, TensorProperty.STRIDE, i)
)
assert all(x is not None for x in stride)
sym_size = [self.create_symintnode(i) for i in size]
sym_size = [self.create_symintnode(i, hint=hint) for i, hint in zip(size, ex.size())]
sym_stride = []
for i, stride_expr in enumerate(stride):
# NB: Don't duck size the stride; instead use the expression
# we computed
assert stride_expr is not None
sym_stride.append(self.create_symintnode(stride_expr))
sym_stride.append(self.create_symintnode(stride_expr, hint=ex.stride(i)))
sym_storage_offset = self.create_symintnode(self.create_symbol(
ex.storage_offset(),
TensorPropertySource(source, TensorProperty.STORAGE_OFFSET)
))
), hint=ex.storage_offset())
return sym_size, sym_stride, sym_storage_offset
def create_symintnode(self, sym: "sympy.Expr"):
return SymInt(SymNode(sym, self, int))
# If you know what the current hint value of the SymInt to be created
# is, pass it into hint. Otherwise, pass None and we will make our best
# guess
def create_symintnode(self, sym: "sympy.Expr", *, hint: Optional[int]):
return SymInt(SymNode(sym, self, int, hint))
def create_unbacked_symfloat(self):
symbol = Symbol(f"f{next(self.unbacked_symfloat_counter)}")
symbol.stack = ''.join(traceback.format_list(traceback.extract_stack()[:-1]))
return SymFloat(SymNode(symbol, self, float))
return SymFloat(SymNode(symbol, self, float, None))
def create_unbacked_symint(self):
symbol = Symbol(f"i{next(self.unbacked_symint_counter)}", integer=True)
symbol.stack = ''.join(traceback.format_list(traceback.extract_stack()[:-1]))
return SymInt(SymNode(symbol, self, int))
return SymInt(SymNode(symbol, self, int, None))
# This is guaranteed to return a symbol or its negation is a sympy.Symbol,
# but there may be a replacement that allows it to be immediately
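
Note that the `hint=None` path used at the Inductor call sites earlier in this diff is not a dead end: as long as every free symbol in the expression is backed, the hint is still recovered by substituting `var_to_val`, just as `size_hint` always did. A small sketch under that assumption (API names as of this commit):

```python
from torch._dynamo.source import ConstantSource
from torch.fx.experimental.symbolic_shapes import ShapeEnv

env = ShapeEnv()
s0 = env.create_symbol(4, source=ConstantSource("x"))  # backed; its value is 4 on this run

# The caller knows the runtime value: the hint is stored directly on the node.
a = env.create_symintnode(s0, hint=4)
assert a.node.hint == 4

# The caller only has a sympy expression (the Inductor situation): pass
# hint=None and the hint is recovered by substituting var_to_val.
b = env.create_symintnode(s0 * 2, hint=None)
assert b.node.hint == 8
```
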
@@ -1217,12 +1300,12 @@ class ShapeEnv:
return self.replacements[a]
@lru_cache(256)
def _maybe_guard_eq(self, expr: Union["sympy.Eq", "sympy.Ne"]) -> None:
def _maybe_guard_eq(self, expr: Union["sympy.Eq", "sympy.Ne"], concrete_bool: bool) -> None:
"""
Evaluates the result of an eq call. If true, uses information to
simplify shapes (i.e. a == b or a % 5 == 0)
"""
concrete_bool = bool(self.size_hint(expr))
assert type(concrete_bool) is bool
if isinstance(expr, sympy.Eq):
if not concrete_bool:
return
@@ -1266,7 +1349,7 @@ class ShapeEnv:
return
@lru_cache(256)
def evaluate_expr(self, expr: "sympy.Expr"):
def evaluate_expr(self, expr: "sympy.Expr", hint=None):
"""
Given an expression, evaluates it, adding guards if necessary
"""
@@ -1277,13 +1360,17 @@ class ShapeEnv:
if static_expr is not None:
return static_expr
if hint is None:
concrete_val = self.size_hint(expr)
else:
concrete_val = sympy.sympify(hint)
if isinstance(expr, (sympy.Eq, sympy.Ne)):
self._maybe_guard_eq(expr)
self._maybe_guard_eq(expr, bool(concrete_val))
# TODO: If we successfully eliminate a symbol via equality, it
# is not actually necessary to save a guard for the equality,
# as we will implicitly generate a guard when we match that
# input against the symbol
concrete_val = self.size_hint(expr)
# TODO: optimize this; avoid formatting traces until we need them
# NB: drop two frames; evaluate_expr and the Sym* function that