Track and record hint on SymNode and use when possible (#94201)
Historically, we worked out `size_hint` on the fly by substituting the `var_to_val` mapping into the sympy expression. With this change, we also maintain the hint directly on the SymNode (in `SymNode._hint`) and use it in lieu of the sympy substitution when it is available (mostly for guards on SymInt and the like; in idiomatic Inductor code, by contrast, we typically manipulate sympy expressions directly and so have no convenient way to maintain hints). While it is possible this will give us modest performance improvements, that is not the point of this PR; the goal is to make it easier to handle unbacked SymInts carefully, where hints are expected not to be available. You can now easily test whether a SymInt is backed by checking whether `symint.node.hint is None`.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94201
Approved by: https://github.com/voznesenskym
This commit is contained in:
parent
b5ef37b9a4
commit
dc70b00d0b
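To illustrate the backed-vs-unbacked distinction described above, here is a small sketch (not part of the commit); it mirrors the create_symint test helper changed below and assumes a freshly constructed ShapeEnv:

from torch._dynamo.source import ConstantSource
from torch.fx.experimental.symbolic_shapes import ShapeEnv

shape_env = ShapeEnv()

# Backed SymInt: wraps a value observed during this trace, so a hint is recorded.
backed = shape_env.create_symintnode(
    shape_env.create_symbol(5, source=ConstantSource("x")),  # "x" is an arbitrary example name
    hint=5,
)
# Unbacked SymInt: the result of data-dependent computation we have not run, so no hint.
unbacked = shape_env.create_unbacked_symint()

assert backed.node.hint == 5
assert unbacked.node.hint is None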
@@ -125,10 +125,11 @@ def create_symbolic_tensor(name, arg, shape_env):
         shape_env.create_symbolic_sizes_strides_storage_offset(arg, source=ConstantSource(name))
     return FakeSymbolicTensor(sym_shapes, sym_strides, arg.dtype, arg.layout, arg.requires_grad, arg.device, sym_storage_offset)
 
-def create_symint(shape_env, i):
+def create_symint(shape_env, i: int):
     from torch._dynamo.source import ConstantSource
     return shape_env.create_symintnode(
-        shape_env.create_symbol(i, source=ConstantSource(f"__testing_only{len(shape_env.var_to_val)}"))
+        shape_env.create_symbol(i, source=ConstantSource(f"__testing_only{len(shape_env.var_to_val)}")),
+        hint=i
     )
 
 @skipIfTorchDynamo("Creating ShapeEnv fails for confusing reasons (also we never expect dynamo to see code like this)")
@@ -478,10 +479,7 @@ class TestSymNumberMagicMethods(TestCase):
             return torch.SymFloat(to_node(seed_node, inp))
 
         def maybe_xfail(inp1, inp2):
-            if fn == "sym_sqrt" and inp1 < 0 and type(inp1) in (SymFloat, SymInt):
-                # TypeError: Cannot convert complex to float
-                return self.assertRaises((TypeError,))
-            elif fn == "sym_sqrt" and inp1 < 0:
+            if fn == "sym_sqrt" and inp1 < 0:
                 # ValueError: math domain error
                 return self.assertRaises((ValueError,))
             elif fn in ("truediv", "floordiv", "mod") and inp2 == 0:
@@ -413,6 +413,9 @@ def sym_max(a, b):
     if isinstance(a, (SymInt, SymFloat)):
         return a.__sym_max__(b)
     elif isinstance(b, (SymInt, SymFloat)):
+        # NB: If you actually care about preserving output type exactly
+        # if you do something like max(0, 0.0), it is NOT sound to treat
+        # min/max as commutative
         return b.__sym_max__(a)
     return builtins.max(a, b)  # type: ignore[operator]
 
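The NB comment added above is really about plain Python numeric semantics; a quick illustration (not part of the change):

# builtins.max returns its first argument when the two compare equal, so
# swapping arguments silently changes the output type:
assert type(max(0, 0.0)) is int    # -> 0
assert type(max(0.0, 0)) is float  # -> 0.0
# This is why sym_max cannot simply flip its arguments and treat max as commutative.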
@@ -687,7 +687,7 @@ class VariableBuilder:
         ):
             shape_env = self.tx.output.shape_env
             wrapped_value = shape_env.create_symintnode(
-                shape_env.create_symbol(value, source=self.source)
+                shape_env.create_symbol(value, source=self.source), hint=value
             )
             self.tx.output.tracked_fakes.append(
                 TrackedFake(wrapped_value, self.source)
@@ -6,6 +6,7 @@ from typing import Dict, List
 
 import torch.fx
 import torch.random
+from torch.fx.experimental.symbolic_shapes import guard_scalar
 
 from .. import config, variables
 from ..exc import unimplemented
@@ -460,9 +461,7 @@ class SymNodeVariable(VariableTracker):
         return self.proxy
 
     def evaluate_expr(self, output_graph):
-        if not isinstance(self.sym_num, torch.SymInt):
-            return self.sym_num
-        return output_graph.shape_env.evaluate_expr(self.sym_num.node.expr)
+        return guard_scalar(self.sym_num)
 
     def call_method(
         self,
@@ -1,4 +1,5 @@
 from torch.fx.experimental.proxy_tensor import is_sym_node, py_sym_types
+from torch.fx.experimental.symbolic_shapes import hint_int
 import torch
 import torch.fx as fx
 import operator
@@ -221,21 +222,14 @@ def _tensor_nbytes(numel, dtype):
     return numel * sizes[dtype]
 
 def _size_of(node: fx.Node) -> int:
-    def to_size_hint(s):
-        if isinstance(s, torch.SymInt):
-            py_s = s.node
-            return py_s.shape_env.size_hint(py_s.expr)
-        assert isinstance(s, int)
-        return s
-
     if 'val' in node.meta:
         val = node.meta['val']
         if isinstance(val, py_sym_types):
             return 1
         elif isinstance(val, (list, tuple)):
-            return sum(_tensor_nbytes(to_size_hint(n.numel()), n.dtype) for n in val if isinstance(n, torch.Tensor))
+            return sum(_tensor_nbytes(hint_int(n.numel()), n.dtype) for n in val if isinstance(n, torch.Tensor))
         elif isinstance(val, torch.Tensor):
-            return _tensor_nbytes(to_size_hint(val.numel()), val.dtype)
+            return _tensor_nbytes(hint_int(val.numel()), val.dtype)
 
         raise RuntimeError(f"Unknown metadata type {type(val)}")
 
@@ -2482,7 +2482,7 @@ class ExternKernel(InputsKernel):
                 tensor_args.append(arg)
             else:
                 if isinstance(arg, sympy.Expr):
-                    arg = V.graph.sizevars.shape_env.create_symintnode(arg)
+                    arg = V.graph.sizevars.shape_env.create_symintnode(arg, hint=None)
                 non_tensor_args.append(arg)
 
         def unflatten_args(new_tensor_args, new_non_tensor_args):
@@ -107,7 +107,7 @@ def convert_shape_to_symint(
         if isinstance(i, int)
         else int(i)
         if isinstance(i, sympy.Integer)
-        else V.graph.sizevars.shape_env.create_symintnode(i)
+        else V.graph.sizevars.shape_env.create_symintnode(i, hint=None)
         for i in lst
     ]
 
@@ -37,8 +37,8 @@ aten = torch._ops.ops.aten  # type: ignore[has-type]
 
 __all__ = [
     "has_symbolic_sizes_strides", "create_contiguous", "ShapeEnv",
-    "SymDispatchMode", "FloorDiv", "guard_int", "guard_float", "wrap_node",
-    "method_to_operator", "SYMPY_INTERP",
+    "SymDispatchMode", "FloorDiv", "guard_int", "guard_float", "guard_scalar", "wrap_node",
+    "method_to_operator", "hint_int", "SYMPY_INTERP",
 ]
 
 SYM_FUNCTION_MODE = None
@@ -104,22 +104,38 @@ def _handle_sym_dispatch(func, args, kwargs):
     finally:
         SYM_FUNCTION_MODE = mode
 
+def hint_int(a):
+    if isinstance(a, torch.SymInt):
+        return a.node.require_hint()
+    assert type(a) is int, a
+    return a
+
+def guard_scalar(a):
+    if isinstance(a, (SymBool, bool)):
+        return guard_bool(a)
+    elif isinstance(a, (SymInt, int)):
+        return guard_int(a)
+    elif isinstance(a, (SymFloat, float)):
+        return guard_float(a)
+    else:
+        raise AssertionError(f"unrecognized scalar {a}")
+
 def guard_bool(a):
     if isinstance(a, SymBool):
         return a.node.guard_bool("", 0)  # NB: uses Python backtrace
-    assert type(a) is bool
+    assert type(a) is bool, a
     return a
 
 def guard_int(a):
     if isinstance(a, SymInt):
         return a.node.guard_int("", 0)  # NB: uses Python backtrace
-    assert type(a) is int
+    assert type(a) is int, a
     return a
 
 def guard_float(a):
     if isinstance(a, SymFloat):
         return a.node.guard_float("", 0)  # NB: uses Python backtrace
-    assert isinstance(a, float)
+    assert isinstance(a, float), a
     return a
 
 # Drop in replacement for math.sqrt
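The two new helpers serve different purposes: guard_scalar concretizes a symbolic value and records a guard, while hint_int merely reads the recorded hint and fails on unbacked values. A usage sketch, not part of the diff, assuming `s` is a backed SymInt and `u` an unbacked one:

hint_int(s)       # returns s.node.require_hint(); no guard is installed
hint_int(7)       # plain ints pass straight through
guard_scalar(s)   # dispatches to guard_int, which evaluates the expression and adds a guard
hint_int(u)       # raises a data-dependent error: unbacked SymInts carry no hint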
@@ -163,17 +179,67 @@ class SymNode:
     This is a type erased SymInt/SymFloat which we use to do actual operations.
     End users don't touch this.  Magic methods are NOT defined on this object.
     """
-    def __init__(self, expr, shape_env, pytype, constant=None):
+    def __init__(self, expr, shape_env, pytype, hint: Optional[Union[int, float]], constant=None):
         self._expr = expr
         self.shape_env = shape_env
         self.pytype = pytype
-        self.constant = constant
+        # What's the difference between hint and constant?
+        #
+        # - A constant is known to be invariant across invocations of the model;
+        #   it will always be this value.  We only really know this when we
+        #   encounter an honest-to-goodness literal (when wrapping it into
+        #   a SymNode, we set constant.)  Most of the time, constant is None
+        #
+        # - A hint is a *particular* value from the particular run we are
+        #   tracing, but it may vary the next time around.  It's useful to
+        #   keep this around, as if we need a concrete value from a SymNode,
+        #   we will return the hint and guard on the expression that produced
+        #   it giving the same hint next time around.  The hint is not
+        #   guaranteed to be set either: if you have an unbacked SymNode,
+        #   there won't be any hint; it was the result of some tensor-dependent
+        #   computation, but we don't know what it actually is because we
+        #   haven't actually run the tensor computation.
+        #
+        # hint_expr is only set if we don't have a hint.  When it is set, it
+        # contains the expression which contains the unbacked symnodes that,
+        # if constrained, would allow this expression to be hinted again.
+        if hint is None:
+            self._hint_expr = self.expr.xreplace(shape_env.var_to_val)
+            self._hint = None
+            self._update_hint()  # check if the replacement actually was enough
+        else:
+            self._hint_expr = None
+            self._hint = hint
+        self.constant: Optional[Union[int, float, bool]] = constant
 
     @property
     def expr(self):
         self._update_expr()
         return self._expr
 
+    # Check if we have replacements hint_expr that would allow us to
+    # simplify it into a hint
+    def _update_hint(self):
+        if self._hint_expr.free_symbols <= self.shape_env.replacements.keys():
+            self._hint = self.pytype(self.shape_env.replace(self._hint_expr))
+            self._hint_expr = None
+
+    @property
+    def hint(self):
+        if self._hint is None:
+            self._update_hint()
+        return self._hint
+
+    def require_hint(self):
+        if self._hint is None:
+            self._update_hint()
+            if self._hint is None:
+                raise self.shape_env._make_data_dependent_error(self._hint_expr)
+            else:
+                return self._hint
+        else:
+            return self._hint
+
     def _update_expr(self):
         self._expr = self.shape_env.replace(self._expr)
 
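To make the hint/constant distinction above concrete, a small sketch (not part of the diff) using the wrap_int path updated below; `backed` and `unbacked` are the SymInts from the earlier example:

literal = backed.node.wrap_int(3)   # wrapping a literal records both constant=3 and hint=3
assert literal.constant == 3 and literal.hint == 3

assert backed.node.constant is None      # a traced size may change between runs
assert backed.node.hint == 5             # but we observed 5 on this run
assert unbacked.node.constant is None and unbacked.node.hint is None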
@@ -188,15 +254,15 @@ class SymNode:
 
     def wrap_int(self, num):
         assert type(num) is int
-        return SymNode(sympy.Integer(num), self.shape_env, int, constant=num)
+        return SymNode(sympy.Integer(num), self.shape_env, int, num, constant=num)
 
     def wrap_float(self, num):
         assert type(num) is float
-        return SymNode(sympy.Float(num), self.shape_env, float, constant=num)
+        return SymNode(sympy.Float(num), self.shape_env, float, num, constant=num)
 
     def wrap_bool(self, num):
         assert type(num) is bool
-        return SymNode(sympy.true if num else sympy.false, self.shape_env, bool, constant=num)
+        return SymNode(sympy.true if num else sympy.false, self.shape_env, bool, num, constant=num)
 
     def clone(self):
         return self
@@ -240,7 +306,7 @@ class SymNode:
     def guard_int(self, file, line):
         # TODO: use the file/line for some useful diagnostic on why a
         # guard occurred
-        r = self.shape_env.evaluate_expr(self.expr)
+        r = self.shape_env.evaluate_expr(self.expr, self.hint)
         try:
             return int(r)
         except Exception:
@@ -250,7 +316,7 @@ class SymNode:
     def guard_float(self, file, line):
         # TODO: use the file/line for some useful diagnostic on why a
         # guard occurred
-        r = self.shape_env.evaluate_expr(self.expr)
+        r = self.shape_env.evaluate_expr(self.expr, self.hint)
         try:
             return float(r)
         except Exception:
@@ -261,7 +327,7 @@ class SymNode:
         # TODO: use the file/line for some useful diagnostic on why a
         # guard occurred
         # TODO: why is the replace needed here?
-        r = self.shape_env.evaluate_expr(self.shape_env.replace(self.expr))
+        r = self.shape_env.evaluate_expr(self.shape_env.replace(self.expr), self.hint)
         try:
             return bool(r)
         except Exception:
@@ -564,6 +630,9 @@ def _make_node_magic(method, func):
             log.warning(f"failed to eval {method}({expr}, {other_expr})")
             raise
         out = safe_expand(out)
+        out_hint = None
+        if self.hint is not None and other.hint is not None:
+            out_hint = op(self.hint, other.hint)
         pytype: Type
         # This is not strictly correct. In Python, a**b may return complex when
         # a < 0 and b is a float: (-1)**2.1. Same for sympy.sqrt(-3.14). This
@@ -581,11 +650,11 @@ def _make_node_magic(method, func):
         else:
             pytype = self.pytype
 
-        return SymNode(out, self.shape_env, pytype)
+        return SymNode(out, self.shape_env, pytype, out_hint)
 
     def unary_magic_impl(self):
-        if SYM_FUNCTION_MODE:
         op = method_to_operator(method)
+        if SYM_FUNCTION_MODE:
             r = _handle_sym_dispatch(op, (wrap_node(self),), {})
             assert isinstance(r, SymTypes), type(r)
             return r.node
@@ -596,6 +665,9 @@ def _make_node_magic(method, func):
         except Exception:
             log.warning(f"failed to eval {method}({expr})")
             raise
+        out_hint = None
+        if self.hint is not None:
+            out_hint = op(self.hint)
         out = safe_expand(out)
         pytype: Type
         if method in always_int_magic_methods:
@@ -605,7 +677,7 @@ def _make_node_magic(method, func):
         else:
             pytype = self.pytype
 
-        return SymNode(out, self.shape_env, pytype)
+        return SymNode(out, self.shape_env, pytype, out_hint)
 
     if method in unary_magic_methods:
         setattr(SymNode, method_attr, unary_magic_impl)
@@ -628,8 +700,16 @@ def _make_node_sizes_strides(method, func):
         except Exception:
             log.warning(f"failed to eval {method}(*{size_exprs}, *{stride_exprs})")
             raise
+        hints = []
+        out_hint = None
+        for s in itertools.chain(sizes, strides):
+            if s.hint is None:
+                break
+            hints.append(s.hint)
+        else:
+            out_hint = op(*hints)
         # bool is never expandable
-        return SymNode(sympy.Eq(out, 1), self.shape_env, bool)
+        return SymNode(sympy.Eq(out, 1), self.shape_env, bool, out_hint)
 
     setattr(SymNode, method, sizes_strides_impl)
 
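With the propagation above, hints flow through SymInt arithmetic: a result carries a hint only when every operand does. A sketch (not part of the diff), reusing `backed` and `unbacked` from the earlier example:

c = backed * 2              # 2 is wrapped with hint=2, so out_hint = 5 * 2
assert c.node.hint == 10

d = backed * unbacked       # one operand has no hint, so the result has none either
assert d.node.hint is None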
@@ -824,31 +904,34 @@ class ShapeEnv:
                     TensorPropertySource(source, TensorProperty.STRIDE, i)
                 )
         assert all(x is not None for x in stride)
-        sym_size = [self.create_symintnode(i) for i in size]
+        sym_size = [self.create_symintnode(i, hint=hint) for i, hint in zip(size, ex.size())]
         sym_stride = []
         for i, stride_expr in enumerate(stride):
             # NB: Don't duck size the stride; instead use the expression
             # we computed
             assert stride_expr is not None
-            sym_stride.append(self.create_symintnode(stride_expr))
+            sym_stride.append(self.create_symintnode(stride_expr, hint=ex.stride(i)))
         sym_storage_offset = self.create_symintnode(self.create_symbol(
             ex.storage_offset(),
             TensorPropertySource(source, TensorProperty.STORAGE_OFFSET)
-        ))
+        ), hint=ex.storage_offset())
         return sym_size, sym_stride, sym_storage_offset
 
-    def create_symintnode(self, sym: "sympy.Expr"):
-        return SymInt(SymNode(sym, self, int))
+    # If you know what the current hint value of the SymInt to be created
+    # is, pass it into hint.  Otherwise, pass None and we will make our best
+    # guess
+    def create_symintnode(self, sym: "sympy.Expr", *, hint: Optional[int]):
+        return SymInt(SymNode(sym, self, int, hint))
 
     def create_unbacked_symfloat(self):
         symbol = Symbol(f"f{next(self.unbacked_symfloat_counter)}")
         symbol.stack = ''.join(traceback.format_list(traceback.extract_stack()[:-1]))
-        return SymFloat(SymNode(symbol, self, float))
+        return SymFloat(SymNode(symbol, self, float, None))
 
     def create_unbacked_symint(self):
         symbol = Symbol(f"i{next(self.unbacked_symint_counter)}", integer=True)
         symbol.stack = ''.join(traceback.format_list(traceback.extract_stack()[:-1]))
-        return SymInt(SymNode(symbol, self, int))
+        return SymInt(SymNode(symbol, self, int, None))
 
     # This is guaranteed to return a symbol or its negation is a sympy.Symbol,
     # but there may be a replacement that allows it to be immediately
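The comment on create_symintnode above summarizes the two calling conventions; a brief sketch (not part of the diff), where `expr` stands for some sympy expression over backed symbols:

s1 = shape_env.create_symintnode(expr, hint=4)     # caller observed the value 4 on this run
s2 = shape_env.create_symintnode(expr, hint=None)  # e.g. Inductor: the hint is recovered, if
                                                   # possible, by substituting var_to_val into expr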
@@ -1217,12 +1300,12 @@ class ShapeEnv:
             return self.replacements[a]
 
     @lru_cache(256)
-    def _maybe_guard_eq(self, expr: Union["sympy.Eq", "sympy.Ne"]) -> None:
+    def _maybe_guard_eq(self, expr: Union["sympy.Eq", "sympy.Ne"], concrete_bool: bool) -> None:
         """
         Evaluates the result of an eq call. If true, uses information to
         simplify shapes (i.e. a == b or a % 5 == 0)
         """
-        concrete_bool = bool(self.size_hint(expr))
+        assert type(concrete_bool) is bool
         if isinstance(expr, sympy.Eq):
             if not concrete_bool:
                 return
@@ -1266,7 +1349,7 @@ class ShapeEnv:
             return
 
     @lru_cache(256)
-    def evaluate_expr(self, expr: "sympy.Expr"):
+    def evaluate_expr(self, expr: "sympy.Expr", hint=None):
         """
         Given an expression, evaluates it, adding guards if necessary
         """
@@ -1277,13 +1360,17 @@ class ShapeEnv:
         if static_expr is not None:
             return static_expr
 
+        if hint is None:
+            concrete_val = self.size_hint(expr)
+        else:
+            concrete_val = sympy.sympify(hint)
+
         if isinstance(expr, (sympy.Eq, sympy.Ne)):
-            self._maybe_guard_eq(expr)
+            self._maybe_guard_eq(expr, bool(concrete_val))
             # TODO: If we successfully eliminate a symbol via equality, it
             # is not actually necessary to save a guard for the equality,
             # as we will implicitly generate a guard when we match that
             # input against the symbol
-        concrete_val = self.size_hint(expr)
 
         # TODO: optimize this; avoid formatting traces until we need them
         # NB: drop two frames; evaluate_expr and the Sym* function that
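The net effect on guards: guard_int/guard_float/guard_bool now pass the node's hint into evaluate_expr, which uses it directly as concrete_val instead of re-substituting var_to_val; only unhinted expressions still fall back to size_hint. A closing sketch (not part of the diff), using the create_symint test helper from the first hunk:

s = create_symint(shape_env, 5)
b = (s * 2 == 10)               # a SymBool whose hint propagated to True
assert guard_bool(b) is True    # evaluate_expr uses the hint as concrete_val and records the guard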