mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Track and record hint on SymNode and use when possible (#94201)
Historically, we work out `size_hint` by working it out on the fly by doing a substitution on the sympy expression with the `var_to_val` mapping. With this change, we also maintain the hint directly on SymNode (in `expr._hint`) and use it in lieu of Sympy substitution when it is available (mostly guards on SymInt, etc; in particular, in idiomatic Inductor code, we typically manipulate Sympy expressions directly and so do not have a way to conveniently maintain hints.) While it's possible this will give us modest performance improvements, this is not the point of this PR; the goal is to make it easier to carefully handle unbacked SymInts, where hints are expected not to be available. You can now easily test if a SymInt is backed or not by checking `symint.node.hint is None`. Signed-off-by: Edward Z. Yang <ezyang@meta.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/94201 Approved by: https://github.com/voznesenskym
This commit is contained in:
parent
b5ef37b9a4
commit
dc70b00d0b
|
|
@ -125,10 +125,11 @@ def create_symbolic_tensor(name, arg, shape_env):
|
||||||
shape_env.create_symbolic_sizes_strides_storage_offset(arg, source=ConstantSource(name))
|
shape_env.create_symbolic_sizes_strides_storage_offset(arg, source=ConstantSource(name))
|
||||||
return FakeSymbolicTensor(sym_shapes, sym_strides, arg.dtype, arg.layout, arg.requires_grad, arg.device, sym_storage_offset)
|
return FakeSymbolicTensor(sym_shapes, sym_strides, arg.dtype, arg.layout, arg.requires_grad, arg.device, sym_storage_offset)
|
||||||
|
|
||||||
def create_symint(shape_env, i):
|
def create_symint(shape_env, i: int):
|
||||||
from torch._dynamo.source import ConstantSource
|
from torch._dynamo.source import ConstantSource
|
||||||
return shape_env.create_symintnode(
|
return shape_env.create_symintnode(
|
||||||
shape_env.create_symbol(i, source=ConstantSource(f"__testing_only{len(shape_env.var_to_val)}"))
|
shape_env.create_symbol(i, source=ConstantSource(f"__testing_only{len(shape_env.var_to_val)}")),
|
||||||
|
hint=i
|
||||||
)
|
)
|
||||||
|
|
||||||
@skipIfTorchDynamo("Creating ShapeEnv fails for confusing reasons (also we never expect dynamo to see code like this)")
|
@skipIfTorchDynamo("Creating ShapeEnv fails for confusing reasons (also we never expect dynamo to see code like this)")
|
||||||
|
|
@ -478,10 +479,7 @@ class TestSymNumberMagicMethods(TestCase):
|
||||||
return torch.SymFloat(to_node(seed_node, inp))
|
return torch.SymFloat(to_node(seed_node, inp))
|
||||||
|
|
||||||
def maybe_xfail(inp1, inp2):
|
def maybe_xfail(inp1, inp2):
|
||||||
if fn == "sym_sqrt" and inp1 < 0 and type(inp1) in (SymFloat, SymInt):
|
if fn == "sym_sqrt" and inp1 < 0:
|
||||||
# TypeError: Cannot convert complex to float
|
|
||||||
return self.assertRaises((TypeError,))
|
|
||||||
elif fn == "sym_sqrt" and inp1 < 0:
|
|
||||||
# ValueError: math domain error
|
# ValueError: math domain error
|
||||||
return self.assertRaises((ValueError,))
|
return self.assertRaises((ValueError,))
|
||||||
elif fn in ("truediv", "floordiv", "mod") and inp2 == 0:
|
elif fn in ("truediv", "floordiv", "mod") and inp2 == 0:
|
||||||
|
|
|
||||||
|
|
@ -413,6 +413,9 @@ def sym_max(a, b):
|
||||||
if isinstance(a, (SymInt, SymFloat)):
|
if isinstance(a, (SymInt, SymFloat)):
|
||||||
return a.__sym_max__(b)
|
return a.__sym_max__(b)
|
||||||
elif isinstance(b, (SymInt, SymFloat)):
|
elif isinstance(b, (SymInt, SymFloat)):
|
||||||
|
# NB: If you actually care about preserving output type exactly
|
||||||
|
# if you do something like max(0, 0.0), it is NOT sound to treat
|
||||||
|
# min/max as commutative
|
||||||
return b.__sym_max__(a)
|
return b.__sym_max__(a)
|
||||||
return builtins.max(a, b) # type: ignore[operator]
|
return builtins.max(a, b) # type: ignore[operator]
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -687,7 +687,7 @@ class VariableBuilder:
|
||||||
):
|
):
|
||||||
shape_env = self.tx.output.shape_env
|
shape_env = self.tx.output.shape_env
|
||||||
wrapped_value = shape_env.create_symintnode(
|
wrapped_value = shape_env.create_symintnode(
|
||||||
shape_env.create_symbol(value, source=self.source)
|
shape_env.create_symbol(value, source=self.source), hint=value
|
||||||
)
|
)
|
||||||
self.tx.output.tracked_fakes.append(
|
self.tx.output.tracked_fakes.append(
|
||||||
TrackedFake(wrapped_value, self.source)
|
TrackedFake(wrapped_value, self.source)
|
||||||
|
|
|
||||||
|
|
@ -6,6 +6,7 @@ from typing import Dict, List
|
||||||
|
|
||||||
import torch.fx
|
import torch.fx
|
||||||
import torch.random
|
import torch.random
|
||||||
|
from torch.fx.experimental.symbolic_shapes import guard_scalar
|
||||||
|
|
||||||
from .. import config, variables
|
from .. import config, variables
|
||||||
from ..exc import unimplemented
|
from ..exc import unimplemented
|
||||||
|
|
@ -460,9 +461,7 @@ class SymNodeVariable(VariableTracker):
|
||||||
return self.proxy
|
return self.proxy
|
||||||
|
|
||||||
def evaluate_expr(self, output_graph):
|
def evaluate_expr(self, output_graph):
|
||||||
if not isinstance(self.sym_num, torch.SymInt):
|
return guard_scalar(self.sym_num)
|
||||||
return self.sym_num
|
|
||||||
return output_graph.shape_env.evaluate_expr(self.sym_num.node.expr)
|
|
||||||
|
|
||||||
def call_method(
|
def call_method(
|
||||||
self,
|
self,
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,5 @@
|
||||||
from torch.fx.experimental.proxy_tensor import is_sym_node, py_sym_types
|
from torch.fx.experimental.proxy_tensor import is_sym_node, py_sym_types
|
||||||
|
from torch.fx.experimental.symbolic_shapes import hint_int
|
||||||
import torch
|
import torch
|
||||||
import torch.fx as fx
|
import torch.fx as fx
|
||||||
import operator
|
import operator
|
||||||
|
|
@ -221,21 +222,14 @@ def _tensor_nbytes(numel, dtype):
|
||||||
return numel * sizes[dtype]
|
return numel * sizes[dtype]
|
||||||
|
|
||||||
def _size_of(node: fx.Node) -> int:
|
def _size_of(node: fx.Node) -> int:
|
||||||
def to_size_hint(s):
|
|
||||||
if isinstance(s, torch.SymInt):
|
|
||||||
py_s = s.node
|
|
||||||
return py_s.shape_env.size_hint(py_s.expr)
|
|
||||||
assert isinstance(s, int)
|
|
||||||
return s
|
|
||||||
|
|
||||||
if 'val' in node.meta:
|
if 'val' in node.meta:
|
||||||
val = node.meta['val']
|
val = node.meta['val']
|
||||||
if isinstance(val, py_sym_types):
|
if isinstance(val, py_sym_types):
|
||||||
return 1
|
return 1
|
||||||
elif isinstance(val, (list, tuple)):
|
elif isinstance(val, (list, tuple)):
|
||||||
return sum(_tensor_nbytes(to_size_hint(n.numel()), n.dtype) for n in val if isinstance(n, torch.Tensor))
|
return sum(_tensor_nbytes(hint_int(n.numel()), n.dtype) for n in val if isinstance(n, torch.Tensor))
|
||||||
elif isinstance(val, torch.Tensor):
|
elif isinstance(val, torch.Tensor):
|
||||||
return _tensor_nbytes(to_size_hint(val.numel()), val.dtype)
|
return _tensor_nbytes(hint_int(val.numel()), val.dtype)
|
||||||
|
|
||||||
raise RuntimeError(f"Unknown metadata type {type(val)}")
|
raise RuntimeError(f"Unknown metadata type {type(val)}")
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -2482,7 +2482,7 @@ class ExternKernel(InputsKernel):
|
||||||
tensor_args.append(arg)
|
tensor_args.append(arg)
|
||||||
else:
|
else:
|
||||||
if isinstance(arg, sympy.Expr):
|
if isinstance(arg, sympy.Expr):
|
||||||
arg = V.graph.sizevars.shape_env.create_symintnode(arg)
|
arg = V.graph.sizevars.shape_env.create_symintnode(arg, hint=None)
|
||||||
non_tensor_args.append(arg)
|
non_tensor_args.append(arg)
|
||||||
|
|
||||||
def unflatten_args(new_tensor_args, new_non_tensor_args):
|
def unflatten_args(new_tensor_args, new_non_tensor_args):
|
||||||
|
|
|
||||||
|
|
@ -107,7 +107,7 @@ def convert_shape_to_symint(
|
||||||
if isinstance(i, int)
|
if isinstance(i, int)
|
||||||
else int(i)
|
else int(i)
|
||||||
if isinstance(i, sympy.Integer)
|
if isinstance(i, sympy.Integer)
|
||||||
else V.graph.sizevars.shape_env.create_symintnode(i)
|
else V.graph.sizevars.shape_env.create_symintnode(i, hint=None)
|
||||||
for i in lst
|
for i in lst
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -37,8 +37,8 @@ aten = torch._ops.ops.aten # type: ignore[has-type]
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"has_symbolic_sizes_strides", "create_contiguous", "ShapeEnv",
|
"has_symbolic_sizes_strides", "create_contiguous", "ShapeEnv",
|
||||||
"SymDispatchMode", "FloorDiv", "guard_int", "guard_float", "wrap_node",
|
"SymDispatchMode", "FloorDiv", "guard_int", "guard_float", "guard_scalar", "wrap_node",
|
||||||
"method_to_operator", "SYMPY_INTERP",
|
"method_to_operator", "hint_int", "SYMPY_INTERP",
|
||||||
]
|
]
|
||||||
|
|
||||||
SYM_FUNCTION_MODE = None
|
SYM_FUNCTION_MODE = None
|
||||||
|
|
@ -104,22 +104,38 @@ def _handle_sym_dispatch(func, args, kwargs):
|
||||||
finally:
|
finally:
|
||||||
SYM_FUNCTION_MODE = mode
|
SYM_FUNCTION_MODE = mode
|
||||||
|
|
||||||
|
def hint_int(a):
|
||||||
|
if isinstance(a, torch.SymInt):
|
||||||
|
return a.node.require_hint()
|
||||||
|
assert type(a) is int, a
|
||||||
|
return a
|
||||||
|
|
||||||
|
def guard_scalar(a):
|
||||||
|
if isinstance(a, (SymBool, bool)):
|
||||||
|
return guard_bool(a)
|
||||||
|
elif isinstance(a, (SymInt, int)):
|
||||||
|
return guard_int(a)
|
||||||
|
elif isinstance(a, (SymFloat, float)):
|
||||||
|
return guard_float(a)
|
||||||
|
else:
|
||||||
|
raise AssertionError(f"unrecognized scalar {a}")
|
||||||
|
|
||||||
def guard_bool(a):
|
def guard_bool(a):
|
||||||
if isinstance(a, SymBool):
|
if isinstance(a, SymBool):
|
||||||
return a.node.guard_bool("", 0) # NB: uses Python backtrace
|
return a.node.guard_bool("", 0) # NB: uses Python backtrace
|
||||||
assert type(a) is bool
|
assert type(a) is bool, a
|
||||||
return a
|
return a
|
||||||
|
|
||||||
def guard_int(a):
|
def guard_int(a):
|
||||||
if isinstance(a, SymInt):
|
if isinstance(a, SymInt):
|
||||||
return a.node.guard_int("", 0) # NB: uses Python backtrace
|
return a.node.guard_int("", 0) # NB: uses Python backtrace
|
||||||
assert type(a) is int
|
assert type(a) is int, a
|
||||||
return a
|
return a
|
||||||
|
|
||||||
def guard_float(a):
|
def guard_float(a):
|
||||||
if isinstance(a, SymFloat):
|
if isinstance(a, SymFloat):
|
||||||
return a.node.guard_float("", 0) # NB: uses Python backtrace
|
return a.node.guard_float("", 0) # NB: uses Python backtrace
|
||||||
assert isinstance(a, float)
|
assert isinstance(a, float), a
|
||||||
return a
|
return a
|
||||||
|
|
||||||
# Drop in replacement for math.sqrt
|
# Drop in replacement for math.sqrt
|
||||||
|
|
@ -163,17 +179,67 @@ class SymNode:
|
||||||
This is a type erased SymInt/SymFloat which we use to do actual operations.
|
This is a type erased SymInt/SymFloat which we use to do actual operations.
|
||||||
End users don't touch this. Magic methods are NOT defined on this object.
|
End users don't touch this. Magic methods are NOT defined on this object.
|
||||||
"""
|
"""
|
||||||
def __init__(self, expr, shape_env, pytype, constant=None):
|
def __init__(self, expr, shape_env, pytype, hint: Optional[Union[int, float]], constant=None):
|
||||||
self._expr = expr
|
self._expr = expr
|
||||||
self.shape_env = shape_env
|
self.shape_env = shape_env
|
||||||
self.pytype = pytype
|
self.pytype = pytype
|
||||||
self.constant = constant
|
# What's the difference between hint and constant?
|
||||||
|
#
|
||||||
|
# - A constant is known to be invariant across invocations of the model;
|
||||||
|
# it will always be this value. We only really know this when we
|
||||||
|
# encounter an honest-to-goodness literal (when wrapping it into
|
||||||
|
# a SymNode, we set constant.) Most of the time, constant is None
|
||||||
|
#
|
||||||
|
# - A hint is a *particular* value from the particular run we are
|
||||||
|
# tracing, but it may vary the next time around. It's useful to
|
||||||
|
# keep this around, as if we need a concrete value from a SymNode,
|
||||||
|
# we will return the hint and guard on the expression that produced
|
||||||
|
# it giving the same hint next time around. The hint is not
|
||||||
|
# guaranteed to be set either: if you have an unbacked SymNode,
|
||||||
|
# there won't be any hint; it was the result of some tensor-dependent
|
||||||
|
# computation, but we don't know what it actually is because we
|
||||||
|
# haven't actually run the tensor computation.
|
||||||
|
#
|
||||||
|
# hint_expr is only set if we don't have a hint. When it is set, it
|
||||||
|
# contains the expression which contains the unbacked symnodes that,
|
||||||
|
# if constrained, would allow this expression to be hinted again.
|
||||||
|
if hint is None:
|
||||||
|
self._hint_expr = self.expr.xreplace(shape_env.var_to_val)
|
||||||
|
self._hint = None
|
||||||
|
self._update_hint() # check if the replacement actually was enough
|
||||||
|
else:
|
||||||
|
self._hint_expr = None
|
||||||
|
self._hint = hint
|
||||||
|
self.constant: Optional[Union[int, float, bool]] = constant
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def expr(self):
|
def expr(self):
|
||||||
self._update_expr()
|
self._update_expr()
|
||||||
return self._expr
|
return self._expr
|
||||||
|
|
||||||
|
# Check if we have replacements hint_expr that would allow us to
|
||||||
|
# simplify it into a hint
|
||||||
|
def _update_hint(self):
|
||||||
|
if self._hint_expr.free_symbols <= self.shape_env.replacements.keys():
|
||||||
|
self._hint = self.pytype(self.shape_env.replace(self._hint_expr))
|
||||||
|
self._hint_expr = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def hint(self):
|
||||||
|
if self._hint is None:
|
||||||
|
self._update_hint()
|
||||||
|
return self._hint
|
||||||
|
|
||||||
|
def require_hint(self):
|
||||||
|
if self._hint is None:
|
||||||
|
self._update_hint()
|
||||||
|
if self._hint is None:
|
||||||
|
raise self.shape_env._make_data_dependent_error(self._hint_expr)
|
||||||
|
else:
|
||||||
|
return self._hint
|
||||||
|
else:
|
||||||
|
return self._hint
|
||||||
|
|
||||||
def _update_expr(self):
|
def _update_expr(self):
|
||||||
self._expr = self.shape_env.replace(self._expr)
|
self._expr = self.shape_env.replace(self._expr)
|
||||||
|
|
||||||
|
|
@ -188,15 +254,15 @@ class SymNode:
|
||||||
|
|
||||||
def wrap_int(self, num):
|
def wrap_int(self, num):
|
||||||
assert type(num) is int
|
assert type(num) is int
|
||||||
return SymNode(sympy.Integer(num), self.shape_env, int, constant=num)
|
return SymNode(sympy.Integer(num), self.shape_env, int, num, constant=num)
|
||||||
|
|
||||||
def wrap_float(self, num):
|
def wrap_float(self, num):
|
||||||
assert type(num) is float
|
assert type(num) is float
|
||||||
return SymNode(sympy.Float(num), self.shape_env, float, constant=num)
|
return SymNode(sympy.Float(num), self.shape_env, float, num, constant=num)
|
||||||
|
|
||||||
def wrap_bool(self, num):
|
def wrap_bool(self, num):
|
||||||
assert type(num) is bool
|
assert type(num) is bool
|
||||||
return SymNode(sympy.true if num else sympy.false, self.shape_env, bool, constant=num)
|
return SymNode(sympy.true if num else sympy.false, self.shape_env, bool, num, constant=num)
|
||||||
|
|
||||||
def clone(self):
|
def clone(self):
|
||||||
return self
|
return self
|
||||||
|
|
@ -240,7 +306,7 @@ class SymNode:
|
||||||
def guard_int(self, file, line):
|
def guard_int(self, file, line):
|
||||||
# TODO: use the file/line for some useful diagnostic on why a
|
# TODO: use the file/line for some useful diagnostic on why a
|
||||||
# guard occurred
|
# guard occurred
|
||||||
r = self.shape_env.evaluate_expr(self.expr)
|
r = self.shape_env.evaluate_expr(self.expr, self.hint)
|
||||||
try:
|
try:
|
||||||
return int(r)
|
return int(r)
|
||||||
except Exception:
|
except Exception:
|
||||||
|
|
@ -250,7 +316,7 @@ class SymNode:
|
||||||
def guard_float(self, file, line):
|
def guard_float(self, file, line):
|
||||||
# TODO: use the file/line for some useful diagnostic on why a
|
# TODO: use the file/line for some useful diagnostic on why a
|
||||||
# guard occurred
|
# guard occurred
|
||||||
r = self.shape_env.evaluate_expr(self.expr)
|
r = self.shape_env.evaluate_expr(self.expr, self.hint)
|
||||||
try:
|
try:
|
||||||
return float(r)
|
return float(r)
|
||||||
except Exception:
|
except Exception:
|
||||||
|
|
@ -261,7 +327,7 @@ class SymNode:
|
||||||
# TODO: use the file/line for some useful diagnostic on why a
|
# TODO: use the file/line for some useful diagnostic on why a
|
||||||
# guard occurred
|
# guard occurred
|
||||||
# TODO: why is the replace needed here?
|
# TODO: why is the replace needed here?
|
||||||
r = self.shape_env.evaluate_expr(self.shape_env.replace(self.expr))
|
r = self.shape_env.evaluate_expr(self.shape_env.replace(self.expr), self.hint)
|
||||||
try:
|
try:
|
||||||
return bool(r)
|
return bool(r)
|
||||||
except Exception:
|
except Exception:
|
||||||
|
|
@ -564,6 +630,9 @@ def _make_node_magic(method, func):
|
||||||
log.warning(f"failed to eval {method}({expr}, {other_expr})")
|
log.warning(f"failed to eval {method}({expr}, {other_expr})")
|
||||||
raise
|
raise
|
||||||
out = safe_expand(out)
|
out = safe_expand(out)
|
||||||
|
out_hint = None
|
||||||
|
if self.hint is not None and other.hint is not None:
|
||||||
|
out_hint = op(self.hint, other.hint)
|
||||||
pytype: Type
|
pytype: Type
|
||||||
# This is not strictly correct. In Python, a**b may return complex when
|
# This is not strictly correct. In Python, a**b may return complex when
|
||||||
# a < 0 and b is a float: (-1)**2.1. Same for sympy.sqrt(-3.14). This
|
# a < 0 and b is a float: (-1)**2.1. Same for sympy.sqrt(-3.14). This
|
||||||
|
|
@ -581,11 +650,11 @@ def _make_node_magic(method, func):
|
||||||
else:
|
else:
|
||||||
pytype = self.pytype
|
pytype = self.pytype
|
||||||
|
|
||||||
return SymNode(out, self.shape_env, pytype)
|
return SymNode(out, self.shape_env, pytype, out_hint)
|
||||||
|
|
||||||
def unary_magic_impl(self):
|
def unary_magic_impl(self):
|
||||||
if SYM_FUNCTION_MODE:
|
|
||||||
op = method_to_operator(method)
|
op = method_to_operator(method)
|
||||||
|
if SYM_FUNCTION_MODE:
|
||||||
r = _handle_sym_dispatch(op, (wrap_node(self),), {})
|
r = _handle_sym_dispatch(op, (wrap_node(self),), {})
|
||||||
assert isinstance(r, SymTypes), type(r)
|
assert isinstance(r, SymTypes), type(r)
|
||||||
return r.node
|
return r.node
|
||||||
|
|
@ -596,6 +665,9 @@ def _make_node_magic(method, func):
|
||||||
except Exception:
|
except Exception:
|
||||||
log.warning(f"failed to eval {method}({expr})")
|
log.warning(f"failed to eval {method}({expr})")
|
||||||
raise
|
raise
|
||||||
|
out_hint = None
|
||||||
|
if self.hint is not None:
|
||||||
|
out_hint = op(self.hint)
|
||||||
out = safe_expand(out)
|
out = safe_expand(out)
|
||||||
pytype: Type
|
pytype: Type
|
||||||
if method in always_int_magic_methods:
|
if method in always_int_magic_methods:
|
||||||
|
|
@ -605,7 +677,7 @@ def _make_node_magic(method, func):
|
||||||
else:
|
else:
|
||||||
pytype = self.pytype
|
pytype = self.pytype
|
||||||
|
|
||||||
return SymNode(out, self.shape_env, pytype)
|
return SymNode(out, self.shape_env, pytype, out_hint)
|
||||||
|
|
||||||
if method in unary_magic_methods:
|
if method in unary_magic_methods:
|
||||||
setattr(SymNode, method_attr, unary_magic_impl)
|
setattr(SymNode, method_attr, unary_magic_impl)
|
||||||
|
|
@ -628,8 +700,16 @@ def _make_node_sizes_strides(method, func):
|
||||||
except Exception:
|
except Exception:
|
||||||
log.warning(f"failed to eval {method}(*{size_exprs}, *{stride_exprs})")
|
log.warning(f"failed to eval {method}(*{size_exprs}, *{stride_exprs})")
|
||||||
raise
|
raise
|
||||||
|
hints = []
|
||||||
|
out_hint = None
|
||||||
|
for s in itertools.chain(sizes, strides):
|
||||||
|
if s.hint is None:
|
||||||
|
break
|
||||||
|
hints.append(s.hint)
|
||||||
|
else:
|
||||||
|
out_hint = op(*hints)
|
||||||
# bool is never expandable
|
# bool is never expandable
|
||||||
return SymNode(sympy.Eq(out, 1), self.shape_env, bool)
|
return SymNode(sympy.Eq(out, 1), self.shape_env, bool, out_hint)
|
||||||
|
|
||||||
setattr(SymNode, method, sizes_strides_impl)
|
setattr(SymNode, method, sizes_strides_impl)
|
||||||
|
|
||||||
|
|
@ -824,31 +904,34 @@ class ShapeEnv:
|
||||||
TensorPropertySource(source, TensorProperty.STRIDE, i)
|
TensorPropertySource(source, TensorProperty.STRIDE, i)
|
||||||
)
|
)
|
||||||
assert all(x is not None for x in stride)
|
assert all(x is not None for x in stride)
|
||||||
sym_size = [self.create_symintnode(i) for i in size]
|
sym_size = [self.create_symintnode(i, hint=hint) for i, hint in zip(size, ex.size())]
|
||||||
sym_stride = []
|
sym_stride = []
|
||||||
for i, stride_expr in enumerate(stride):
|
for i, stride_expr in enumerate(stride):
|
||||||
# NB: Don't duck size the stride; instead use the expression
|
# NB: Don't duck size the stride; instead use the expression
|
||||||
# we computed
|
# we computed
|
||||||
assert stride_expr is not None
|
assert stride_expr is not None
|
||||||
sym_stride.append(self.create_symintnode(stride_expr))
|
sym_stride.append(self.create_symintnode(stride_expr, hint=ex.stride(i)))
|
||||||
sym_storage_offset = self.create_symintnode(self.create_symbol(
|
sym_storage_offset = self.create_symintnode(self.create_symbol(
|
||||||
ex.storage_offset(),
|
ex.storage_offset(),
|
||||||
TensorPropertySource(source, TensorProperty.STORAGE_OFFSET)
|
TensorPropertySource(source, TensorProperty.STORAGE_OFFSET)
|
||||||
))
|
), hint=ex.storage_offset())
|
||||||
return sym_size, sym_stride, sym_storage_offset
|
return sym_size, sym_stride, sym_storage_offset
|
||||||
|
|
||||||
def create_symintnode(self, sym: "sympy.Expr"):
|
# If you know what the current hint value of the SymInt to be created
|
||||||
return SymInt(SymNode(sym, self, int))
|
# is, pass it into hint. Otherwise, pass None and we will make our best
|
||||||
|
# guess
|
||||||
|
def create_symintnode(self, sym: "sympy.Expr", *, hint: Optional[int]):
|
||||||
|
return SymInt(SymNode(sym, self, int, hint))
|
||||||
|
|
||||||
def create_unbacked_symfloat(self):
|
def create_unbacked_symfloat(self):
|
||||||
symbol = Symbol(f"f{next(self.unbacked_symfloat_counter)}")
|
symbol = Symbol(f"f{next(self.unbacked_symfloat_counter)}")
|
||||||
symbol.stack = ''.join(traceback.format_list(traceback.extract_stack()[:-1]))
|
symbol.stack = ''.join(traceback.format_list(traceback.extract_stack()[:-1]))
|
||||||
return SymFloat(SymNode(symbol, self, float))
|
return SymFloat(SymNode(symbol, self, float, None))
|
||||||
|
|
||||||
def create_unbacked_symint(self):
|
def create_unbacked_symint(self):
|
||||||
symbol = Symbol(f"i{next(self.unbacked_symint_counter)}", integer=True)
|
symbol = Symbol(f"i{next(self.unbacked_symint_counter)}", integer=True)
|
||||||
symbol.stack = ''.join(traceback.format_list(traceback.extract_stack()[:-1]))
|
symbol.stack = ''.join(traceback.format_list(traceback.extract_stack()[:-1]))
|
||||||
return SymInt(SymNode(symbol, self, int))
|
return SymInt(SymNode(symbol, self, int, None))
|
||||||
|
|
||||||
# This is guaranteed to return a symbol or its negation is a sympy.Symbol,
|
# This is guaranteed to return a symbol or its negation is a sympy.Symbol,
|
||||||
# but there may be a replacement that allows it to be immediately
|
# but there may be a replacement that allows it to be immediately
|
||||||
|
|
@ -1217,12 +1300,12 @@ class ShapeEnv:
|
||||||
return self.replacements[a]
|
return self.replacements[a]
|
||||||
|
|
||||||
@lru_cache(256)
|
@lru_cache(256)
|
||||||
def _maybe_guard_eq(self, expr: Union["sympy.Eq", "sympy.Ne"]) -> None:
|
def _maybe_guard_eq(self, expr: Union["sympy.Eq", "sympy.Ne"], concrete_bool: bool) -> None:
|
||||||
"""
|
"""
|
||||||
Evaluates the result of an eq call. If true, uses information to
|
Evaluates the result of an eq call. If true, uses information to
|
||||||
simplify shapes (i.e. a == b or a % 5 == 0)
|
simplify shapes (i.e. a == b or a % 5 == 0)
|
||||||
"""
|
"""
|
||||||
concrete_bool = bool(self.size_hint(expr))
|
assert type(concrete_bool) is bool
|
||||||
if isinstance(expr, sympy.Eq):
|
if isinstance(expr, sympy.Eq):
|
||||||
if not concrete_bool:
|
if not concrete_bool:
|
||||||
return
|
return
|
||||||
|
|
@ -1266,7 +1349,7 @@ class ShapeEnv:
|
||||||
return
|
return
|
||||||
|
|
||||||
@lru_cache(256)
|
@lru_cache(256)
|
||||||
def evaluate_expr(self, expr: "sympy.Expr"):
|
def evaluate_expr(self, expr: "sympy.Expr", hint=None):
|
||||||
"""
|
"""
|
||||||
Given an expression, evaluates it, adding guards if necessary
|
Given an expression, evaluates it, adding guards if necessary
|
||||||
"""
|
"""
|
||||||
|
|
@ -1277,13 +1360,17 @@ class ShapeEnv:
|
||||||
if static_expr is not None:
|
if static_expr is not None:
|
||||||
return static_expr
|
return static_expr
|
||||||
|
|
||||||
|
if hint is None:
|
||||||
|
concrete_val = self.size_hint(expr)
|
||||||
|
else:
|
||||||
|
concrete_val = sympy.sympify(hint)
|
||||||
|
|
||||||
if isinstance(expr, (sympy.Eq, sympy.Ne)):
|
if isinstance(expr, (sympy.Eq, sympy.Ne)):
|
||||||
self._maybe_guard_eq(expr)
|
self._maybe_guard_eq(expr, bool(concrete_val))
|
||||||
# TODO: If we successfully eliminate a symbol via equality, it
|
# TODO: If we successfully eliminate a symbol via equality, it
|
||||||
# is not actually necessary to save a guard for the equality,
|
# is not actually necessary to save a guard for the equality,
|
||||||
# as we will implicitly generate a guard when we match that
|
# as we will implicitly generate a guard when we match that
|
||||||
# input against the symbol
|
# input against the symbol
|
||||||
concrete_val = self.size_hint(expr)
|
|
||||||
|
|
||||||
# TODO: optimize this; avoid formatting traces until we need them
|
# TODO: optimize this; avoid formatting traces until we need them
|
||||||
# NB: drop two frames; evaluate_expr and the Sym* function that
|
# NB: drop two frames; evaluate_expr and the Sym* function that
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user