diff --git a/test/test_dynamic_shapes.py b/test/test_dynamic_shapes.py index 5e7a5eca394..bc2858c56cc 100644 --- a/test/test_dynamic_shapes.py +++ b/test/test_dynamic_shapes.py @@ -125,10 +125,11 @@ def create_symbolic_tensor(name, arg, shape_env): shape_env.create_symbolic_sizes_strides_storage_offset(arg, source=ConstantSource(name)) return FakeSymbolicTensor(sym_shapes, sym_strides, arg.dtype, arg.layout, arg.requires_grad, arg.device, sym_storage_offset) -def create_symint(shape_env, i): +def create_symint(shape_env, i: int): from torch._dynamo.source import ConstantSource return shape_env.create_symintnode( - shape_env.create_symbol(i, source=ConstantSource(f"__testing_only{len(shape_env.var_to_val)}")) + shape_env.create_symbol(i, source=ConstantSource(f"__testing_only{len(shape_env.var_to_val)}")), + hint=i ) @skipIfTorchDynamo("Creating ShapeEnv fails for confusing reasons (also we never expect dynamo to see code like this)") @@ -478,10 +479,7 @@ class TestSymNumberMagicMethods(TestCase): return torch.SymFloat(to_node(seed_node, inp)) def maybe_xfail(inp1, inp2): - if fn == "sym_sqrt" and inp1 < 0 and type(inp1) in (SymFloat, SymInt): - # TypeError: Cannot convert complex to float - return self.assertRaises((TypeError,)) - elif fn == "sym_sqrt" and inp1 < 0: + if fn == "sym_sqrt" and inp1 < 0: # ValueError: math domain error return self.assertRaises((ValueError,)) elif fn in ("truediv", "floordiv", "mod") and inp2 == 0: diff --git a/torch/__init__.py b/torch/__init__.py index 72601100ee9..040d4bb2724 100644 --- a/torch/__init__.py +++ b/torch/__init__.py @@ -413,6 +413,9 @@ def sym_max(a, b): if isinstance(a, (SymInt, SymFloat)): return a.__sym_max__(b) elif isinstance(b, (SymInt, SymFloat)): + # NB: If you actually care about preserving output type exactly + # if you do something like max(0, 0.0), it is NOT sound to treat + # min/max as commutative return b.__sym_max__(a) return builtins.max(a, b) # type: ignore[operator] diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py index 6af600fba79..db4e9ef7b34 100644 --- a/torch/_dynamo/variables/builder.py +++ b/torch/_dynamo/variables/builder.py @@ -687,7 +687,7 @@ class VariableBuilder: ): shape_env = self.tx.output.shape_env wrapped_value = shape_env.create_symintnode( - shape_env.create_symbol(value, source=self.source) + shape_env.create_symbol(value, source=self.source), hint=value ) self.tx.output.tracked_fakes.append( TrackedFake(wrapped_value, self.source) diff --git a/torch/_dynamo/variables/tensor.py b/torch/_dynamo/variables/tensor.py index 1f648f0f403..9e09f378ac8 100644 --- a/torch/_dynamo/variables/tensor.py +++ b/torch/_dynamo/variables/tensor.py @@ -6,6 +6,7 @@ from typing import Dict, List import torch.fx import torch.random +from torch.fx.experimental.symbolic_shapes import guard_scalar from .. import config, variables from ..exc import unimplemented @@ -460,9 +461,7 @@ class SymNodeVariable(VariableTracker): return self.proxy def evaluate_expr(self, output_graph): - if not isinstance(self.sym_num, torch.SymInt): - return self.sym_num - return output_graph.shape_env.evaluate_expr(self.sym_num.node.expr) + return guard_scalar(self.sym_num) def call_method( self, diff --git a/torch/_functorch/partitioners.py b/torch/_functorch/partitioners.py index 0880e44ee79..63562895d41 100644 --- a/torch/_functorch/partitioners.py +++ b/torch/_functorch/partitioners.py @@ -1,4 +1,5 @@ from torch.fx.experimental.proxy_tensor import is_sym_node, py_sym_types +from torch.fx.experimental.symbolic_shapes import hint_int import torch import torch.fx as fx import operator @@ -221,21 +222,14 @@ def _tensor_nbytes(numel, dtype): return numel * sizes[dtype] def _size_of(node: fx.Node) -> int: - def to_size_hint(s): - if isinstance(s, torch.SymInt): - py_s = s.node - return py_s.shape_env.size_hint(py_s.expr) - assert isinstance(s, int) - return s - if 'val' in node.meta: val = node.meta['val'] if isinstance(val, py_sym_types): return 1 elif isinstance(val, (list, tuple)): - return sum(_tensor_nbytes(to_size_hint(n.numel()), n.dtype) for n in val if isinstance(n, torch.Tensor)) + return sum(_tensor_nbytes(hint_int(n.numel()), n.dtype) for n in val if isinstance(n, torch.Tensor)) elif isinstance(val, torch.Tensor): - return _tensor_nbytes(to_size_hint(val.numel()), val.dtype) + return _tensor_nbytes(hint_int(val.numel()), val.dtype) raise RuntimeError(f"Unknown metadata type {type(val)}") diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index d081c69bb66..2cc1300d00b 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -2482,7 +2482,7 @@ class ExternKernel(InputsKernel): tensor_args.append(arg) else: if isinstance(arg, sympy.Expr): - arg = V.graph.sizevars.shape_env.create_symintnode(arg) + arg = V.graph.sizevars.shape_env.create_symintnode(arg, hint=None) non_tensor_args.append(arg) def unflatten_args(new_tensor_args, new_non_tensor_args): diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py index 84772964c58..7ad739f0168 100644 --- a/torch/_inductor/utils.py +++ b/torch/_inductor/utils.py @@ -107,7 +107,7 @@ def convert_shape_to_symint( if isinstance(i, int) else int(i) if isinstance(i, sympy.Integer) - else V.graph.sizevars.shape_env.create_symintnode(i) + else V.graph.sizevars.shape_env.create_symintnode(i, hint=None) for i in lst ] diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py index a7e9099f19a..37205a3882f 100644 --- a/torch/fx/experimental/symbolic_shapes.py +++ b/torch/fx/experimental/symbolic_shapes.py @@ -37,8 +37,8 @@ aten = torch._ops.ops.aten # type: ignore[has-type] __all__ = [ "has_symbolic_sizes_strides", "create_contiguous", "ShapeEnv", - "SymDispatchMode", "FloorDiv", "guard_int", "guard_float", "wrap_node", - "method_to_operator", "SYMPY_INTERP", + "SymDispatchMode", "FloorDiv", "guard_int", "guard_float", "guard_scalar", "wrap_node", + "method_to_operator", "hint_int", "SYMPY_INTERP", ] SYM_FUNCTION_MODE = None @@ -104,22 +104,38 @@ def _handle_sym_dispatch(func, args, kwargs): finally: SYM_FUNCTION_MODE = mode +def hint_int(a): + if isinstance(a, torch.SymInt): + return a.node.require_hint() + assert type(a) is int, a + return a + +def guard_scalar(a): + if isinstance(a, (SymBool, bool)): + return guard_bool(a) + elif isinstance(a, (SymInt, int)): + return guard_int(a) + elif isinstance(a, (SymFloat, float)): + return guard_float(a) + else: + raise AssertionError(f"unrecognized scalar {a}") + def guard_bool(a): if isinstance(a, SymBool): return a.node.guard_bool("", 0) # NB: uses Python backtrace - assert type(a) is bool + assert type(a) is bool, a return a def guard_int(a): if isinstance(a, SymInt): return a.node.guard_int("", 0) # NB: uses Python backtrace - assert type(a) is int + assert type(a) is int, a return a def guard_float(a): if isinstance(a, SymFloat): return a.node.guard_float("", 0) # NB: uses Python backtrace - assert isinstance(a, float) + assert isinstance(a, float), a return a # Drop in replacement for math.sqrt @@ -163,17 +179,67 @@ class SymNode: This is a type erased SymInt/SymFloat which we use to do actual operations. End users don't touch this. Magic methods are NOT defined on this object. """ - def __init__(self, expr, shape_env, pytype, constant=None): + def __init__(self, expr, shape_env, pytype, hint: Optional[Union[int, float]], constant=None): self._expr = expr self.shape_env = shape_env self.pytype = pytype - self.constant = constant + # What's the difference between hint and constant? + # + # - A constant is known to be invariant across invocations of the model; + # it will always be this value. We only really know this when we + # encounter an honest-to-goodness literal (when wrapping it into + # a SymNode, we set constant.) Most of the time, constant is None + # + # - A hint is a *particular* value from the particular run we are + # tracing, but it may vary the next time around. It's useful to + # keep this around, as if we need a concrete value from a SymNode, + # we will return the hint and guard on the expression that produced + # it giving the same hint next time around. The hint is not + # guaranteed to be set either: if you have an unbacked SymNode, + # there won't be any hint; it was the result of some tensor-dependent + # computation, but we don't know what it actually is because we + # haven't actually run the tensor computation. + # + # hint_expr is only set if we don't have a hint. When it is set, it + # contains the expression which contains the unbacked symnodes that, + # if constrained, would allow this expression to be hinted again. + if hint is None: + self._hint_expr = self.expr.xreplace(shape_env.var_to_val) + self._hint = None + self._update_hint() # check if the replacement actually was enough + else: + self._hint_expr = None + self._hint = hint + self.constant: Optional[Union[int, float, bool]] = constant @property def expr(self): self._update_expr() return self._expr + # Check if we have replacements hint_expr that would allow us to + # simplify it into a hint + def _update_hint(self): + if self._hint_expr.free_symbols <= self.shape_env.replacements.keys(): + self._hint = self.pytype(self.shape_env.replace(self._hint_expr)) + self._hint_expr = None + + @property + def hint(self): + if self._hint is None: + self._update_hint() + return self._hint + + def require_hint(self): + if self._hint is None: + self._update_hint() + if self._hint is None: + raise self.shape_env._make_data_dependent_error(self._hint_expr) + else: + return self._hint + else: + return self._hint + def _update_expr(self): self._expr = self.shape_env.replace(self._expr) @@ -188,15 +254,15 @@ class SymNode: def wrap_int(self, num): assert type(num) is int - return SymNode(sympy.Integer(num), self.shape_env, int, constant=num) + return SymNode(sympy.Integer(num), self.shape_env, int, num, constant=num) def wrap_float(self, num): assert type(num) is float - return SymNode(sympy.Float(num), self.shape_env, float, constant=num) + return SymNode(sympy.Float(num), self.shape_env, float, num, constant=num) def wrap_bool(self, num): assert type(num) is bool - return SymNode(sympy.true if num else sympy.false, self.shape_env, bool, constant=num) + return SymNode(sympy.true if num else sympy.false, self.shape_env, bool, num, constant=num) def clone(self): return self @@ -240,7 +306,7 @@ class SymNode: def guard_int(self, file, line): # TODO: use the file/line for some useful diagnostic on why a # guard occurred - r = self.shape_env.evaluate_expr(self.expr) + r = self.shape_env.evaluate_expr(self.expr, self.hint) try: return int(r) except Exception: @@ -250,7 +316,7 @@ class SymNode: def guard_float(self, file, line): # TODO: use the file/line for some useful diagnostic on why a # guard occurred - r = self.shape_env.evaluate_expr(self.expr) + r = self.shape_env.evaluate_expr(self.expr, self.hint) try: return float(r) except Exception: @@ -261,7 +327,7 @@ class SymNode: # TODO: use the file/line for some useful diagnostic on why a # guard occurred # TODO: why is the replace needed here? - r = self.shape_env.evaluate_expr(self.shape_env.replace(self.expr)) + r = self.shape_env.evaluate_expr(self.shape_env.replace(self.expr), self.hint) try: return bool(r) except Exception: @@ -564,6 +630,9 @@ def _make_node_magic(method, func): log.warning(f"failed to eval {method}({expr}, {other_expr})") raise out = safe_expand(out) + out_hint = None + if self.hint is not None and other.hint is not None: + out_hint = op(self.hint, other.hint) pytype: Type # This is not strictly correct. In Python, a**b may return complex when # a < 0 and b is a float: (-1)**2.1. Same for sympy.sqrt(-3.14). This @@ -581,11 +650,11 @@ def _make_node_magic(method, func): else: pytype = self.pytype - return SymNode(out, self.shape_env, pytype) + return SymNode(out, self.shape_env, pytype, out_hint) def unary_magic_impl(self): + op = method_to_operator(method) if SYM_FUNCTION_MODE: - op = method_to_operator(method) r = _handle_sym_dispatch(op, (wrap_node(self),), {}) assert isinstance(r, SymTypes), type(r) return r.node @@ -596,6 +665,9 @@ def _make_node_magic(method, func): except Exception: log.warning(f"failed to eval {method}({expr})") raise + out_hint = None + if self.hint is not None: + out_hint = op(self.hint) out = safe_expand(out) pytype: Type if method in always_int_magic_methods: @@ -605,7 +677,7 @@ def _make_node_magic(method, func): else: pytype = self.pytype - return SymNode(out, self.shape_env, pytype) + return SymNode(out, self.shape_env, pytype, out_hint) if method in unary_magic_methods: setattr(SymNode, method_attr, unary_magic_impl) @@ -628,8 +700,16 @@ def _make_node_sizes_strides(method, func): except Exception: log.warning(f"failed to eval {method}(*{size_exprs}, *{stride_exprs})") raise + hints = [] + out_hint = None + for s in itertools.chain(sizes, strides): + if s.hint is None: + break + hints.append(s.hint) + else: + out_hint = op(*hints) # bool is never expandable - return SymNode(sympy.Eq(out, 1), self.shape_env, bool) + return SymNode(sympy.Eq(out, 1), self.shape_env, bool, out_hint) setattr(SymNode, method, sizes_strides_impl) @@ -824,31 +904,34 @@ class ShapeEnv: TensorPropertySource(source, TensorProperty.STRIDE, i) ) assert all(x is not None for x in stride) - sym_size = [self.create_symintnode(i) for i in size] + sym_size = [self.create_symintnode(i, hint=hint) for i, hint in zip(size, ex.size())] sym_stride = [] for i, stride_expr in enumerate(stride): # NB: Don't duck size the stride; instead use the expression # we computed assert stride_expr is not None - sym_stride.append(self.create_symintnode(stride_expr)) + sym_stride.append(self.create_symintnode(stride_expr, hint=ex.stride(i))) sym_storage_offset = self.create_symintnode(self.create_symbol( ex.storage_offset(), TensorPropertySource(source, TensorProperty.STORAGE_OFFSET) - )) + ), hint=ex.storage_offset()) return sym_size, sym_stride, sym_storage_offset - def create_symintnode(self, sym: "sympy.Expr"): - return SymInt(SymNode(sym, self, int)) + # If you know what the current hint value of the SymInt to be created + # is, pass it into hint. Otherwise, pass None and we will make our best + # guess + def create_symintnode(self, sym: "sympy.Expr", *, hint: Optional[int]): + return SymInt(SymNode(sym, self, int, hint)) def create_unbacked_symfloat(self): symbol = Symbol(f"f{next(self.unbacked_symfloat_counter)}") symbol.stack = ''.join(traceback.format_list(traceback.extract_stack()[:-1])) - return SymFloat(SymNode(symbol, self, float)) + return SymFloat(SymNode(symbol, self, float, None)) def create_unbacked_symint(self): symbol = Symbol(f"i{next(self.unbacked_symint_counter)}", integer=True) symbol.stack = ''.join(traceback.format_list(traceback.extract_stack()[:-1])) - return SymInt(SymNode(symbol, self, int)) + return SymInt(SymNode(symbol, self, int, None)) # This is guaranteed to return a symbol or its negation is a sympy.Symbol, # but there may be a replacement that allows it to be immediately @@ -1217,12 +1300,12 @@ class ShapeEnv: return self.replacements[a] @lru_cache(256) - def _maybe_guard_eq(self, expr: Union["sympy.Eq", "sympy.Ne"]) -> None: + def _maybe_guard_eq(self, expr: Union["sympy.Eq", "sympy.Ne"], concrete_bool: bool) -> None: """ Evaluates the result of an eq call. If true, uses information to simplify shapes (i.e. a == b or a % 5 == 0) """ - concrete_bool = bool(self.size_hint(expr)) + assert type(concrete_bool) is bool if isinstance(expr, sympy.Eq): if not concrete_bool: return @@ -1266,7 +1349,7 @@ class ShapeEnv: return @lru_cache(256) - def evaluate_expr(self, expr: "sympy.Expr"): + def evaluate_expr(self, expr: "sympy.Expr", hint=None): """ Given an expression, evaluates it, adding guards if necessary """ @@ -1277,13 +1360,17 @@ class ShapeEnv: if static_expr is not None: return static_expr + if hint is None: + concrete_val = self.size_hint(expr) + else: + concrete_val = sympy.sympify(hint) + if isinstance(expr, (sympy.Eq, sympy.Ne)): - self._maybe_guard_eq(expr) + self._maybe_guard_eq(expr, bool(concrete_val)) # TODO: If we successfully eliminate a symbol via equality, it # is not actually necessary to save a guard for the equality, # as we will implicitly generate a guard when we match that # input against the symbol - concrete_val = self.size_hint(expr) # TODO: optimize this; avoid formatting traces until we need them # NB: drop two frames; evaluate_expr and the Sym* function that