Implement duck shaping on SymInts (#85808)

Duck shaping says that when two input tensors have the same
size, we assume the sizes are symbolically related, i.e., they
are assigned the same symbolic variable.  This follows the same
optimization done by inductor.

This optimization is not yet complete: we don't currently
install guards corresponding to the duck shape relationships
we create.  That said, guard propagation for dynamic shape
tracing is incomplete overall at the moment.
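
As a minimal standalone sketch of the idea (ToyShapeEnv is a made-up
stand-in; the real logic lives in ShapeEnv.create_symint below):
sizes that compare equal at trace time share one symbol, while 0 and
1 are specialized to plain ints:

    import sympy

    class ToyShapeEnv:
        def __init__(self):
            self.var_to_val = {}  # symbol -> concrete value seen at trace time
            self.val_to_sym = {}  # concrete value -> symbol (duck-shaping cache)

        def create_sym(self, name, val):
            assert val >= 0
            if val in (0, 1):           # 0/1 sizes are specialized, never symbolic
                return val
            if val in self.val_to_sym:  # duck shaping: equal sizes share a symbol
                return self.val_to_sym[val]
            s = sympy.Symbol(name, positive=True, integer=True)
            self.var_to_val[s] = sympy.Integer(val)
            self.val_to_sym[val] = s
            return s

    env = ToyShapeEnv()
    a = env.create_sym("s0", 5)
    b = env.create_sym("s1", 5)  # same concrete size -> same symbol as `a`
    c = env.create_sym("s2", 6)  # different size -> fresh symbol
    assert a is b and a != c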

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85808
Approved by: https://github.com/albanD
Author: Edward Z. Yang
Date: 2022-09-28 17:28:26 -04:00
Committed-by: PyTorch MergeBot
Parent: 3a3e2002d8
Commit: ada6e5b53a
2 changed files with 24 additions and 8 deletions


@@ -806,8 +806,9 @@ class TestSymbolicTracing(TestCase):
         shape_env = self._test_dynamic(f, [(3, 4)], test_inputs)
         self.assertTrue(shape_env.evaluate_guards_for_args(torch.randn(4, 5)))
         self.assertFalse(shape_env.evaluate_guards_for_args(torch.randn(25, 5)))
-        # one guard for size/stride contiguity, and one substantive guard
-        assert len(shape_env.guards) == 2, "\n" + shape_env.format_guards()
+        # TODO: There should eventually be guards for contiguity, but they're
+        # not currently being done yet
+        assert len(shape_env.guards) == 1, "\n" + shape_env.format_guards()
 
     def test_binary_broadcast(self):
         def f(a, b):
@@ -902,14 +903,16 @@ def forward(self, a_1):
         def f(a, b):
             return a * b
 
-        fx_g = _trace(f, (5, 5), (5, 5))
+        # NB: Numbers are carefully chosen to avoid duck shaping from applying
+        fx_g = _trace(f, (5, 6), (5, 6))
         self._assert_no_guards(fx_g, 2)
 
-        fx_g = _trace(f, (5, 5, 5), (5, 5, 5))
+        fx_g = _trace(f, (5, 6, 7), (5, 6, 7))
         self._assert_no_guards(fx_g, 3)
 
-        fx_g = _trace(f, (5, 1), (1, 5))
-        self._assert_no_guards(fx_g, 3)
+        fx_g = _trace(f, (5, 1), (1, 6))
+        self._assert_no_guards(fx_g, 2)
 
         def f(a, b, c, d):
             a = a + b
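
A hedged illustration (not part of the test suite) of why the sizes
were changed: under duck shaping, every distinct concrete size other
than 0 and 1 gets exactly one symbol, so repeated sizes like (5, 5)
collapse to a single symbol (num_duck_symbols is a made-up helper):

    # How many symbols duck shaping allocates for a set of input sizes
    # (0 and 1 are specialized and never become symbolic).
    def num_duck_symbols(*sizes):
        return len({s for s in sizes if s not in (0, 1)})

    assert num_duck_symbols(5, 5) == 1        # old sizes: one shared symbol
    assert num_duck_symbols(5, 6) == 2        # new sizes: two distinct symbols
    assert num_duck_symbols(5, 1, 1, 6) == 2  # the 1s are specialized away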
@@ -936,7 +939,7 @@ def forward(self, a_1):
         fx_g = _trace(f, (4, 2), 8)
         self._assert_no_guards(fx_g, 2)
 
-        fx_g = _trace(f, (4, 2), (8, 4))
+        fx_g = _trace(f, (4, 2), (8, 5))
         self._assert_no_guards(fx_g, 3)
 
         fx_g = _trace(f, (2, 3, 4), 24)
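
Same motivation in this hunk, reusing the hypothetical
num_duck_symbols helper from above: the 4 in (8, 4) matches the 4 in
(4, 2), so duck shaping would have merged those two dimensions into
one symbol; bumping it to 5 keeps every input size independent:

    assert num_duck_symbols(4, 2, 8, 4) == 3  # old: the repeated 4 collapses
    assert num_duck_symbols(4, 2, 8, 5) == 4  # new: no accidental merging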


@@ -266,6 +266,9 @@ class ShapeEnv(object):
         self.replacements: Dict["sympy.Symbol", "sympy.Expr"] = {}  #
         # Set holds a % b expressions that evaluate to 0.
         self.divisible: Set["sympy.Expr"] = set()
+        # Duck-shaping says that if two input tensors have the same size,
+        # they get assigned the same symbolic variable
+        self.val_to_symint: Dict[int, torch.SymIntNode] = {}
 
     def _get_key(self):
         """
@@ -274,17 +277,27 @@ class ShapeEnv(object):
         """
         return (len(self.replacements), len(self.divisible))
 
+    # NB: This is only called for input symbolic sizes; intermediate symbolic
+    # sizes are allocated via a different mechanism
     def create_symint(self, name, val):
+        assert val >= 0
         if not HAS_SYMPY:
             raise RuntimeError("Need sympy installed to create symbolic shapes")
-        # Currently we don't put 0/1 specialization in guards but perhaps we should
+        # TODO: Put 0/1 specialization in guards
         if val == 0 or val == 1:
             return val
+        # This implements duck-shaping: input sizes that match are assigned
+        # the same symint
+        # TODO: Create a guard whenever this happens
+        # TODO: But how do I represent the guard in this case?
+        if val in self.val_to_symint:
+            return self.val_to_symint[val]
         sympy_expr = sympy.Symbol(name, positive=True, integer=True)
         py_sym_int = PySymInt(sympy_expr, self)
         cpp_sym_int = torch.SymIntNode.new_symint(py_sym_int)  # type: ignore[attr-defined]
         self.var_to_val[sympy_expr] = sympy.Integer(val)
+        self.val_to_symint[val] = cpp_sym_int
         return cpp_sym_int
 
     def evaluate_guards_for_args(self, *args):
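
A hedged usage sketch of the cache above, assuming sympy is
installed, that ShapeEnv() constructs with no arguments (as its
__init__ above suggests), and the module layout as of this commit;
the API has moved around in later PyTorch versions:

    from torch.fx.experimental.symbolic_shapes import ShapeEnv

    env = ShapeEnv()
    a = env.create_symint("s0", 5)
    b = env.create_symint("s1", 5)  # cache hit in val_to_symint: same SymIntNode
    c = env.create_symint("s2", 6)  # fresh symbol
    assert a is b
    assert env.create_symint("s3", 1) == 1  # 0/1 come back as plain ints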