mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Implement duck shaping on SymInts (#85808)
Duck shaping says that when two input tensors have the same size, we assume they are symbolically related. This follows the same optimization done by inductor. This optimization is not done completely because we don't currently install guards corresponding to the duck shape relationships we created, but overall the guard propagation for dynamic shape tracing is incomplete at the moment. Signed-off-by: Edward Z. Yang <ezyang@fb.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/85808 Approved by: https://github.com/albanD
This commit is contained in:
parent
3a3e2002d8
commit
ada6e5b53a
|
|
@ -806,8 +806,9 @@ class TestSymbolicTracing(TestCase):
|
|||
shape_env = self._test_dynamic(f, [(3, 4)], test_inputs)
|
||||
self.assertTrue(shape_env.evaluate_guards_for_args(torch.randn(4, 5)))
|
||||
self.assertFalse(shape_env.evaluate_guards_for_args(torch.randn(25, 5)))
|
||||
# one guard for size/stride contiguity, and one substantive guard
|
||||
assert len(shape_env.guards) == 2, "\n" + shape_env.format_guards()
|
||||
# TODO: There should eventually be guards for contiguity, but they're
|
||||
# not currently being done yet
|
||||
assert len(shape_env.guards) == 1, "\n" + shape_env.format_guards()
|
||||
|
||||
def test_binary_broadcast(self):
|
||||
def f(a, b):
|
||||
|
|
@ -902,14 +903,16 @@ def forward(self, a_1):
|
|||
def f(a, b):
|
||||
return a * b
|
||||
|
||||
fx_g = _trace(f, (5, 5), (5, 5))
|
||||
# NB: Numbers are carefully chosen to avoid duck shaping from applying
|
||||
|
||||
fx_g = _trace(f, (5, 6), (5, 6))
|
||||
self._assert_no_guards(fx_g, 2)
|
||||
|
||||
fx_g = _trace(f, (5, 5, 5), (5, 5, 5))
|
||||
fx_g = _trace(f, (5, 6, 7), (5, 6, 7))
|
||||
self._assert_no_guards(fx_g, 3)
|
||||
|
||||
fx_g = _trace(f, (5, 1), (1, 5))
|
||||
self._assert_no_guards(fx_g, 3)
|
||||
fx_g = _trace(f, (5, 1), (1, 6))
|
||||
self._assert_no_guards(fx_g, 2)
|
||||
|
||||
def f(a, b, c, d):
|
||||
a = a + b
|
||||
|
|
@ -936,7 +939,7 @@ def forward(self, a_1):
|
|||
fx_g = _trace(f, (4, 2), 8)
|
||||
self._assert_no_guards(fx_g, 2)
|
||||
|
||||
fx_g = _trace(f, (4, 2), (8, 4))
|
||||
fx_g = _trace(f, (4, 2), (8, 5))
|
||||
self._assert_no_guards(fx_g, 3)
|
||||
|
||||
fx_g = _trace(f, (2, 3, 4), 24)
|
||||
|
|
|
|||
|
|
@ -266,6 +266,9 @@ class ShapeEnv(object):
|
|||
self.replacements: Dict["sympy.Symbol", "sympy.Expr"] = {} #
|
||||
# Set holds a % b expressions that evaluate to 0.
|
||||
self.divisible: Set["sympy.Expr"] = set()
|
||||
# Duck-shaping says that if two input tensors have the same size,
|
||||
# they get assigned the same symbolic variable
|
||||
self.val_to_symint: Dict[int, torch.SymIntNode] = {}
|
||||
|
||||
def _get_key(self):
|
||||
"""
|
||||
|
|
@ -274,17 +277,27 @@ class ShapeEnv(object):
|
|||
"""
|
||||
return (len(self.replacements), len(self.divisible))
|
||||
|
||||
# NB: This is only called for input symbolic sizes; intermediate symbolic
|
||||
# sizes are allocated via a different mechanism
|
||||
def create_symint(self, name, val):
|
||||
assert val >= 0
|
||||
if not HAS_SYMPY:
|
||||
raise RuntimeError("Need sympy installed to create symbolic shapes")
|
||||
|
||||
# Currently we don't put 0/1 specialization in guards but perhaps we should
|
||||
# TODO: Put 0/1 specialization in guards
|
||||
if val == 0 or val == 1:
|
||||
return val
|
||||
# This implements duck-shaping: input sizes that match are assigned
|
||||
# the same symint
|
||||
# TODO: Create a guard whenever this happens
|
||||
# TODO: But how do I represent the guard in this case?
|
||||
if val in self.val_to_symint:
|
||||
return self.val_to_symint[val]
|
||||
sympy_expr = sympy.Symbol(name, positive=True, integer=True)
|
||||
py_sym_int = PySymInt(sympy_expr, self)
|
||||
cpp_sym_int = torch.SymIntNode.new_symint(py_sym_int) # type: ignore[attr-defined]
|
||||
self.var_to_val[sympy_expr] = sympy.Integer(val)
|
||||
self.val_to_symint[val] = cpp_sym_int
|
||||
return cpp_sym_int
|
||||
|
||||
def evaluate_guards_for_args(self, *args):
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user