Track and record hint on SymNode and use when possible (#94201)

Historically, we computed `size_hint` on the fly by substituting the `var_to_val` mapping into the sympy expression. With this change, we also maintain the hint directly on SymNode (in `SymNode._hint`) and use it in lieu of sympy substitution when it is available (mostly for guards on SymInt, etc.; notably, in idiomatic Inductor code we typically manipulate sympy expressions directly and so do not have a convenient way to maintain hints).

While it's possible this will give us modest performance improvements, that is not the point of this PR; the goal is to make it easier to carefully handle unbacked SymInts, where hints are expected not to be available. You can now easily test whether a SymInt is backed by checking `symint.node.hint is None` (a hint of `None` means the SymInt is unbacked).

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94201
Approved by: https://github.com/voznesenskym
This commit is contained in:
Edward Z. Yang 2023-02-08 14:03:49 -05:00 committed by PyTorch MergeBot
parent b5ef37b9a4
commit dc70b00d0b
8 changed files with 131 additions and 50 deletions

View File

@ -125,10 +125,11 @@ def create_symbolic_tensor(name, arg, shape_env):
shape_env.create_symbolic_sizes_strides_storage_offset(arg, source=ConstantSource(name)) shape_env.create_symbolic_sizes_strides_storage_offset(arg, source=ConstantSource(name))
return FakeSymbolicTensor(sym_shapes, sym_strides, arg.dtype, arg.layout, arg.requires_grad, arg.device, sym_storage_offset) return FakeSymbolicTensor(sym_shapes, sym_strides, arg.dtype, arg.layout, arg.requires_grad, arg.device, sym_storage_offset)
def create_symint(shape_env, i): def create_symint(shape_env, i: int):
from torch._dynamo.source import ConstantSource from torch._dynamo.source import ConstantSource
return shape_env.create_symintnode( return shape_env.create_symintnode(
shape_env.create_symbol(i, source=ConstantSource(f"__testing_only{len(shape_env.var_to_val)}")) shape_env.create_symbol(i, source=ConstantSource(f"__testing_only{len(shape_env.var_to_val)}")),
hint=i
) )
@skipIfTorchDynamo("Creating ShapeEnv fails for confusing reasons (also we never expect dynamo to see code like this)") @skipIfTorchDynamo("Creating ShapeEnv fails for confusing reasons (also we never expect dynamo to see code like this)")
@ -478,10 +479,7 @@ class TestSymNumberMagicMethods(TestCase):
return torch.SymFloat(to_node(seed_node, inp)) return torch.SymFloat(to_node(seed_node, inp))
def maybe_xfail(inp1, inp2): def maybe_xfail(inp1, inp2):
if fn == "sym_sqrt" and inp1 < 0 and type(inp1) in (SymFloat, SymInt): if fn == "sym_sqrt" and inp1 < 0:
# TypeError: Cannot convert complex to float
return self.assertRaises((TypeError,))
elif fn == "sym_sqrt" and inp1 < 0:
# ValueError: math domain error # ValueError: math domain error
return self.assertRaises((ValueError,)) return self.assertRaises((ValueError,))
elif fn in ("truediv", "floordiv", "mod") and inp2 == 0: elif fn in ("truediv", "floordiv", "mod") and inp2 == 0:

View File

@ -413,6 +413,9 @@ def sym_max(a, b):
if isinstance(a, (SymInt, SymFloat)): if isinstance(a, (SymInt, SymFloat)):
return a.__sym_max__(b) return a.__sym_max__(b)
elif isinstance(b, (SymInt, SymFloat)): elif isinstance(b, (SymInt, SymFloat)):
# NB: If you actually care about preserving the output type exactly
# (e.g. when you do something like max(0, 0.0)), it is NOT sound to treat
# min/max as commutative
return b.__sym_max__(a) return b.__sym_max__(a)
return builtins.max(a, b) # type: ignore[operator] return builtins.max(a, b) # type: ignore[operator]

View File

@ -687,7 +687,7 @@ class VariableBuilder:
): ):
shape_env = self.tx.output.shape_env shape_env = self.tx.output.shape_env
wrapped_value = shape_env.create_symintnode( wrapped_value = shape_env.create_symintnode(
shape_env.create_symbol(value, source=self.source) shape_env.create_symbol(value, source=self.source), hint=value
) )
self.tx.output.tracked_fakes.append( self.tx.output.tracked_fakes.append(
TrackedFake(wrapped_value, self.source) TrackedFake(wrapped_value, self.source)

View File

@ -6,6 +6,7 @@ from typing import Dict, List
import torch.fx import torch.fx
import torch.random import torch.random
from torch.fx.experimental.symbolic_shapes import guard_scalar
from .. import config, variables from .. import config, variables
from ..exc import unimplemented from ..exc import unimplemented
@ -460,9 +461,7 @@ class SymNodeVariable(VariableTracker):
return self.proxy return self.proxy
def evaluate_expr(self, output_graph): def evaluate_expr(self, output_graph):
if not isinstance(self.sym_num, torch.SymInt): return guard_scalar(self.sym_num)
return self.sym_num
return output_graph.shape_env.evaluate_expr(self.sym_num.node.expr)
def call_method( def call_method(
self, self,

View File

@ -1,4 +1,5 @@
from torch.fx.experimental.proxy_tensor import is_sym_node, py_sym_types from torch.fx.experimental.proxy_tensor import is_sym_node, py_sym_types
from torch.fx.experimental.symbolic_shapes import hint_int
import torch import torch
import torch.fx as fx import torch.fx as fx
import operator import operator
@ -221,21 +222,14 @@ def _tensor_nbytes(numel, dtype):
return numel * sizes[dtype] return numel * sizes[dtype]
def _size_of(node: fx.Node) -> int: def _size_of(node: fx.Node) -> int:
def to_size_hint(s):
if isinstance(s, torch.SymInt):
py_s = s.node
return py_s.shape_env.size_hint(py_s.expr)
assert isinstance(s, int)
return s
if 'val' in node.meta: if 'val' in node.meta:
val = node.meta['val'] val = node.meta['val']
if isinstance(val, py_sym_types): if isinstance(val, py_sym_types):
return 1 return 1
elif isinstance(val, (list, tuple)): elif isinstance(val, (list, tuple)):
return sum(_tensor_nbytes(to_size_hint(n.numel()), n.dtype) for n in val if isinstance(n, torch.Tensor)) return sum(_tensor_nbytes(hint_int(n.numel()), n.dtype) for n in val if isinstance(n, torch.Tensor))
elif isinstance(val, torch.Tensor): elif isinstance(val, torch.Tensor):
return _tensor_nbytes(to_size_hint(val.numel()), val.dtype) return _tensor_nbytes(hint_int(val.numel()), val.dtype)
raise RuntimeError(f"Unknown metadata type {type(val)}") raise RuntimeError(f"Unknown metadata type {type(val)}")

View File

@ -2482,7 +2482,7 @@ class ExternKernel(InputsKernel):
tensor_args.append(arg) tensor_args.append(arg)
else: else:
if isinstance(arg, sympy.Expr): if isinstance(arg, sympy.Expr):
arg = V.graph.sizevars.shape_env.create_symintnode(arg) arg = V.graph.sizevars.shape_env.create_symintnode(arg, hint=None)
non_tensor_args.append(arg) non_tensor_args.append(arg)
def unflatten_args(new_tensor_args, new_non_tensor_args): def unflatten_args(new_tensor_args, new_non_tensor_args):

View File

@ -107,7 +107,7 @@ def convert_shape_to_symint(
if isinstance(i, int) if isinstance(i, int)
else int(i) else int(i)
if isinstance(i, sympy.Integer) if isinstance(i, sympy.Integer)
else V.graph.sizevars.shape_env.create_symintnode(i) else V.graph.sizevars.shape_env.create_symintnode(i, hint=None)
for i in lst for i in lst
] ]

View File

@ -37,8 +37,8 @@ aten = torch._ops.ops.aten # type: ignore[has-type]
__all__ = [ __all__ = [
"has_symbolic_sizes_strides", "create_contiguous", "ShapeEnv", "has_symbolic_sizes_strides", "create_contiguous", "ShapeEnv",
"SymDispatchMode", "FloorDiv", "guard_int", "guard_float", "wrap_node", "SymDispatchMode", "FloorDiv", "guard_int", "guard_float", "guard_scalar", "wrap_node",
"method_to_operator", "SYMPY_INTERP", "method_to_operator", "hint_int", "SYMPY_INTERP",
] ]
SYM_FUNCTION_MODE = None SYM_FUNCTION_MODE = None
@ -104,22 +104,38 @@ def _handle_sym_dispatch(func, args, kwargs):
finally: finally:
SYM_FUNCTION_MODE = mode SYM_FUNCTION_MODE = mode
def hint_int(a):
if isinstance(a, torch.SymInt):
return a.node.require_hint()
assert type(a) is int, a
return a
def guard_scalar(a):
if isinstance(a, (SymBool, bool)):
return guard_bool(a)
elif isinstance(a, (SymInt, int)):
return guard_int(a)
elif isinstance(a, (SymFloat, float)):
return guard_float(a)
else:
raise AssertionError(f"unrecognized scalar {a}")
def guard_bool(a): def guard_bool(a):
if isinstance(a, SymBool): if isinstance(a, SymBool):
return a.node.guard_bool("", 0) # NB: uses Python backtrace return a.node.guard_bool("", 0) # NB: uses Python backtrace
assert type(a) is bool assert type(a) is bool, a
return a return a
def guard_int(a): def guard_int(a):
if isinstance(a, SymInt): if isinstance(a, SymInt):
return a.node.guard_int("", 0) # NB: uses Python backtrace return a.node.guard_int("", 0) # NB: uses Python backtrace
assert type(a) is int assert type(a) is int, a
return a return a
def guard_float(a): def guard_float(a):
if isinstance(a, SymFloat): if isinstance(a, SymFloat):
return a.node.guard_float("", 0) # NB: uses Python backtrace return a.node.guard_float("", 0) # NB: uses Python backtrace
assert isinstance(a, float) assert isinstance(a, float), a
return a return a
# Drop in replacement for math.sqrt # Drop in replacement for math.sqrt
@ -163,17 +179,67 @@ class SymNode:
This is a type erased SymInt/SymFloat which we use to do actual operations. This is a type erased SymInt/SymFloat which we use to do actual operations.
End users don't touch this. Magic methods are NOT defined on this object. End users don't touch this. Magic methods are NOT defined on this object.
""" """
def __init__(self, expr, shape_env, pytype, constant=None): def __init__(self, expr, shape_env, pytype, hint: Optional[Union[int, float]], constant=None):
self._expr = expr self._expr = expr
self.shape_env = shape_env self.shape_env = shape_env
self.pytype = pytype self.pytype = pytype
self.constant = constant # What's the difference between hint and constant?
#
# - A constant is known to be invariant across invocations of the model;
# it will always be this value. We only really know this when we
# encounter an honest-to-goodness literal (when wrapping it into
# a SymNode, we set constant.) Most of the time, constant is None
#
# - A hint is a *particular* value from the particular run we are
# tracing, but it may vary the next time around. It's useful to
# keep this around, as if we need a concrete value from a SymNode,
# we will return the hint and guard on the expression that produced
# it giving the same hint next time around. The hint is not
# guaranteed to be set either: if you have an unbacked SymNode,
# there won't be any hint; it was the result of some tensor-dependent
# computation, but we don't know what it actually is because we
# haven't actually run the tensor computation.
#
# hint_expr is only set if we don't have a hint. When it is set, it
# contains the expression which contains the unbacked symnodes that,
# if constrained, would allow this expression to be hinted again.
if hint is None:
self._hint_expr = self.expr.xreplace(shape_env.var_to_val)
self._hint = None
self._update_hint() # check if the replacement actually was enough
else:
self._hint_expr = None
self._hint = hint
self.constant: Optional[Union[int, float, bool]] = constant
@property @property
def expr(self): def expr(self):
self._update_expr() self._update_expr()
return self._expr return self._expr
# Check if we have replacements for hint_expr that would allow us to
# simplify it into a hint
def _update_hint(self):
if self._hint_expr.free_symbols <= self.shape_env.replacements.keys():
self._hint = self.pytype(self.shape_env.replace(self._hint_expr))
self._hint_expr = None
@property
def hint(self):
if self._hint is None:
self._update_hint()
return self._hint
def require_hint(self):
if self._hint is None:
self._update_hint()
if self._hint is None:
raise self.shape_env._make_data_dependent_error(self._hint_expr)
else:
return self._hint
else:
return self._hint
def _update_expr(self): def _update_expr(self):
self._expr = self.shape_env.replace(self._expr) self._expr = self.shape_env.replace(self._expr)
@ -188,15 +254,15 @@ class SymNode:
def wrap_int(self, num): def wrap_int(self, num):
assert type(num) is int assert type(num) is int
return SymNode(sympy.Integer(num), self.shape_env, int, constant=num) return SymNode(sympy.Integer(num), self.shape_env, int, num, constant=num)
def wrap_float(self, num): def wrap_float(self, num):
assert type(num) is float assert type(num) is float
return SymNode(sympy.Float(num), self.shape_env, float, constant=num) return SymNode(sympy.Float(num), self.shape_env, float, num, constant=num)
def wrap_bool(self, num): def wrap_bool(self, num):
assert type(num) is bool assert type(num) is bool
return SymNode(sympy.true if num else sympy.false, self.shape_env, bool, constant=num) return SymNode(sympy.true if num else sympy.false, self.shape_env, bool, num, constant=num)
def clone(self): def clone(self):
return self return self
@ -240,7 +306,7 @@ class SymNode:
def guard_int(self, file, line): def guard_int(self, file, line):
# TODO: use the file/line for some useful diagnostic on why a # TODO: use the file/line for some useful diagnostic on why a
# guard occurred # guard occurred
r = self.shape_env.evaluate_expr(self.expr) r = self.shape_env.evaluate_expr(self.expr, self.hint)
try: try:
return int(r) return int(r)
except Exception: except Exception:
@ -250,7 +316,7 @@ class SymNode:
def guard_float(self, file, line): def guard_float(self, file, line):
# TODO: use the file/line for some useful diagnostic on why a # TODO: use the file/line for some useful diagnostic on why a
# guard occurred # guard occurred
r = self.shape_env.evaluate_expr(self.expr) r = self.shape_env.evaluate_expr(self.expr, self.hint)
try: try:
return float(r) return float(r)
except Exception: except Exception:
@ -261,7 +327,7 @@ class SymNode:
# TODO: use the file/line for some useful diagnostic on why a # TODO: use the file/line for some useful diagnostic on why a
# guard occurred # guard occurred
# TODO: why is the replace needed here? # TODO: why is the replace needed here?
r = self.shape_env.evaluate_expr(self.shape_env.replace(self.expr)) r = self.shape_env.evaluate_expr(self.shape_env.replace(self.expr), self.hint)
try: try:
return bool(r) return bool(r)
except Exception: except Exception:
@ -564,6 +630,9 @@ def _make_node_magic(method, func):
log.warning(f"failed to eval {method}({expr}, {other_expr})") log.warning(f"failed to eval {method}({expr}, {other_expr})")
raise raise
out = safe_expand(out) out = safe_expand(out)
out_hint = None
if self.hint is not None and other.hint is not None:
out_hint = op(self.hint, other.hint)
pytype: Type pytype: Type
# This is not strictly correct. In Python, a**b may return complex when # This is not strictly correct. In Python, a**b may return complex when
# a < 0 and b is a float: (-1)**2.1. Same for sympy.sqrt(-3.14). This # a < 0 and b is a float: (-1)**2.1. Same for sympy.sqrt(-3.14). This
@ -581,11 +650,11 @@ def _make_node_magic(method, func):
else: else:
pytype = self.pytype pytype = self.pytype
return SymNode(out, self.shape_env, pytype) return SymNode(out, self.shape_env, pytype, out_hint)
def unary_magic_impl(self): def unary_magic_impl(self):
if SYM_FUNCTION_MODE:
op = method_to_operator(method) op = method_to_operator(method)
if SYM_FUNCTION_MODE:
r = _handle_sym_dispatch(op, (wrap_node(self),), {}) r = _handle_sym_dispatch(op, (wrap_node(self),), {})
assert isinstance(r, SymTypes), type(r) assert isinstance(r, SymTypes), type(r)
return r.node return r.node
@ -596,6 +665,9 @@ def _make_node_magic(method, func):
except Exception: except Exception:
log.warning(f"failed to eval {method}({expr})") log.warning(f"failed to eval {method}({expr})")
raise raise
out_hint = None
if self.hint is not None:
out_hint = op(self.hint)
out = safe_expand(out) out = safe_expand(out)
pytype: Type pytype: Type
if method in always_int_magic_methods: if method in always_int_magic_methods:
@ -605,7 +677,7 @@ def _make_node_magic(method, func):
else: else:
pytype = self.pytype pytype = self.pytype
return SymNode(out, self.shape_env, pytype) return SymNode(out, self.shape_env, pytype, out_hint)
if method in unary_magic_methods: if method in unary_magic_methods:
setattr(SymNode, method_attr, unary_magic_impl) setattr(SymNode, method_attr, unary_magic_impl)
@ -628,8 +700,16 @@ def _make_node_sizes_strides(method, func):
except Exception: except Exception:
log.warning(f"failed to eval {method}(*{size_exprs}, *{stride_exprs})") log.warning(f"failed to eval {method}(*{size_exprs}, *{stride_exprs})")
raise raise
hints = []
out_hint = None
for s in itertools.chain(sizes, strides):
if s.hint is None:
break
hints.append(s.hint)
else:
out_hint = op(*hints)
# bool is never expandable # bool is never expandable
return SymNode(sympy.Eq(out, 1), self.shape_env, bool) return SymNode(sympy.Eq(out, 1), self.shape_env, bool, out_hint)
setattr(SymNode, method, sizes_strides_impl) setattr(SymNode, method, sizes_strides_impl)
@ -824,31 +904,34 @@ class ShapeEnv:
TensorPropertySource(source, TensorProperty.STRIDE, i) TensorPropertySource(source, TensorProperty.STRIDE, i)
) )
assert all(x is not None for x in stride) assert all(x is not None for x in stride)
sym_size = [self.create_symintnode(i) for i in size] sym_size = [self.create_symintnode(i, hint=hint) for i, hint in zip(size, ex.size())]
sym_stride = [] sym_stride = []
for i, stride_expr in enumerate(stride): for i, stride_expr in enumerate(stride):
# NB: Don't duck size the stride; instead use the expression # NB: Don't duck size the stride; instead use the expression
# we computed # we computed
assert stride_expr is not None assert stride_expr is not None
sym_stride.append(self.create_symintnode(stride_expr)) sym_stride.append(self.create_symintnode(stride_expr, hint=ex.stride(i)))
sym_storage_offset = self.create_symintnode(self.create_symbol( sym_storage_offset = self.create_symintnode(self.create_symbol(
ex.storage_offset(), ex.storage_offset(),
TensorPropertySource(source, TensorProperty.STORAGE_OFFSET) TensorPropertySource(source, TensorProperty.STORAGE_OFFSET)
)) ), hint=ex.storage_offset())
return sym_size, sym_stride, sym_storage_offset return sym_size, sym_stride, sym_storage_offset
def create_symintnode(self, sym: "sympy.Expr"): # If you know what the current hint value of the SymInt to be created
return SymInt(SymNode(sym, self, int)) # is, pass it into hint. Otherwise, pass None and we will make our best
# guess
def create_symintnode(self, sym: "sympy.Expr", *, hint: Optional[int]):
return SymInt(SymNode(sym, self, int, hint))
def create_unbacked_symfloat(self): def create_unbacked_symfloat(self):
symbol = Symbol(f"f{next(self.unbacked_symfloat_counter)}") symbol = Symbol(f"f{next(self.unbacked_symfloat_counter)}")
symbol.stack = ''.join(traceback.format_list(traceback.extract_stack()[:-1])) symbol.stack = ''.join(traceback.format_list(traceback.extract_stack()[:-1]))
return SymFloat(SymNode(symbol, self, float)) return SymFloat(SymNode(symbol, self, float, None))
def create_unbacked_symint(self): def create_unbacked_symint(self):
symbol = Symbol(f"i{next(self.unbacked_symint_counter)}", integer=True) symbol = Symbol(f"i{next(self.unbacked_symint_counter)}", integer=True)
symbol.stack = ''.join(traceback.format_list(traceback.extract_stack()[:-1])) symbol.stack = ''.join(traceback.format_list(traceback.extract_stack()[:-1]))
return SymInt(SymNode(symbol, self, int)) return SymInt(SymNode(symbol, self, int, None))
# This is guaranteed to return a symbol or its negation is a sympy.Symbol, # This is guaranteed to return a symbol or its negation is a sympy.Symbol,
# but there may be a replacement that allows it to be immediately # but there may be a replacement that allows it to be immediately
@ -1217,12 +1300,12 @@ class ShapeEnv:
return self.replacements[a] return self.replacements[a]
@lru_cache(256) @lru_cache(256)
def _maybe_guard_eq(self, expr: Union["sympy.Eq", "sympy.Ne"]) -> None: def _maybe_guard_eq(self, expr: Union["sympy.Eq", "sympy.Ne"], concrete_bool: bool) -> None:
""" """
Evaluates the result of an eq call. If true, uses information to Evaluates the result of an eq call. If true, uses information to
simplify shapes (i.e. a == b or a % 5 == 0) simplify shapes (i.e. a == b or a % 5 == 0)
""" """
concrete_bool = bool(self.size_hint(expr)) assert type(concrete_bool) is bool
if isinstance(expr, sympy.Eq): if isinstance(expr, sympy.Eq):
if not concrete_bool: if not concrete_bool:
return return
@ -1266,7 +1349,7 @@ class ShapeEnv:
return return
@lru_cache(256) @lru_cache(256)
def evaluate_expr(self, expr: "sympy.Expr"): def evaluate_expr(self, expr: "sympy.Expr", hint=None):
""" """
Given an expression, evaluates it, adding guards if necessary Given an expression, evaluates it, adding guards if necessary
""" """
@ -1277,13 +1360,17 @@ class ShapeEnv:
if static_expr is not None: if static_expr is not None:
return static_expr return static_expr
if hint is None:
concrete_val = self.size_hint(expr)
else:
concrete_val = sympy.sympify(hint)
if isinstance(expr, (sympy.Eq, sympy.Ne)): if isinstance(expr, (sympy.Eq, sympy.Ne)):
self._maybe_guard_eq(expr) self._maybe_guard_eq(expr, bool(concrete_val))
# TODO: If we successfully eliminate a symbol via equality, it # TODO: If we successfully eliminate a symbol via equality, it
# is not actually necessary to save a guard for the equality, # is not actually necessary to save a guard for the equality,
# as we will implicitly generate a guard when we match that # as we will implicitly generate a guard when we match that
# input against the symbol # input against the symbol
concrete_val = self.size_hint(expr)
# TODO: optimize this; avoid formatting traces until we need them # TODO: optimize this; avoid formatting traces until we need them
# NB: drop two frames; evaluate_expr and the Sym* function that # NB: drop two frames; evaluate_expr and the Sym* function that