Use source hashing to generate consistent symbolic ids (#149665)

This PR was inspired by internal models that were cache missing due to PGO. At a high level, the problem looks as follows:

Run 1, Invocation 1: We do a static compile and save example values to the PGO/automatic dynamic profile.

Run 1, Invocation 2: We detect varying inputs, do a dynamic compile, get a dynamic graph, and save to PGO. Crucially, what we save to PGO is a superset of what is actually dynamic: if we notice an input varying, we mark it as dynamic in PGO even if that value later gets specialized, and when a value is specialized its symbol is removed from the graph. This results in an interesting conundrum: although we produce the same isomorphic graph, PGO makes the second run cache miss. Let's see how.

Run 2, Invocation 1: We fetch the PGO profile, over-mark things as dynamic, get an FX graph, look it up in the cache and... whoops, cache miss! This happens because the PGO profile causes us to over-allocate symbols. In practice, we end up saving a graph in the cache with symbols x:s1, y:s3, and on the second attempt we cache miss with x:s1, y:s6, where symbols s3, s4, s5 were all optimistically marked dynamic by PGO and subsequently specialized.
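
To make the failure mode concrete, here is a minimal, hypothetical sketch (plain Python, not the actual Dynamo/Inductor code) of how a cache key derived from the printed graph changes once the symbol counter has been advanced by symbols that were later specialized away:

```python
import hashlib
from itertools import count

def allocate_symbols(dynamic_sources, counter):
    # Sequential allocation: every source marked dynamic consumes an id,
    # even if its symbol is later specialized away and dropped from the graph.
    return {src: f"s{next(counter)}" for src in dynamic_sources}

def cache_key(graph_template, symbols):
    # The cache key is derived from the printed graph, which embeds symbol names.
    return hashlib.sha256(graph_template.format(**symbols).encode()).hexdigest()

graph = "add(x: {x}, y: {y})"  # stand-in for the printed FX graph

# Run 1: only x and y vary -> x:s0, y:s1.
run1 = allocate_symbols(["x", "y"], count())

# Run 2: PGO also over-marks a, b, c as dynamic; they consume s1..s3 and are
# then specialized away, so y is allocated afterwards and lands on s4.
counter = count()
overmarked = allocate_symbols(["x", "a", "b", "c"], counter)
run2 = {"x": overmarked["x"], "y": f"s{next(counter)}"}

# Isomorphic graph, different symbol names -> different key -> cache miss.
assert cache_key(graph, run1) != cache_key(graph, run2)
```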

We solve this problem by hashing the source names, which gives a reasonably stable symbol assignment across runs. To prevent catastrophic symbol collisions, we use linear probing to guarantee that no two sources share a symbol.
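
As a rough illustration of the idea (a minimal sketch with hypothetical helper names and modulus, not the actual ShapeEnv implementation): derive the symbol id from a hash of the source name, and on collision probe linearly to the next free id, so the assignment no longer depends on how many other symbols happened to be allocated first:

```python
import hashlib

def stable_symbol_id(source_name, used_ids, modulus=10_000):
    # Hash the source name (e.g. "L['x'].size()[0]") into a candidate id.
    candidate = int(hashlib.sha256(source_name.encode()).hexdigest(), 16) % modulus
    # Linear probing: walk forward until we find an id no other symbol is using.
    while candidate in used_ids:
        candidate = (candidate + 1) % modulus
    used_ids.add(candidate)
    return candidate

used = set()
for src in ("L['x'].size()[0]", "L['y'].size()[0]", "L['z'].size()[0]"):
    print(src, "->", f"s{stable_symbol_id(src, used)}")
```

Because the id depends only on the source name (plus probing for the rare collision), specializing away other symbols between runs no longer shifts the ids of the symbols that survive. This is also why the expected symbol names in the test diffs below change from sequential ids like s0, s1 to hashed ids like s77, s27.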

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149665
Approved by: https://github.com/Mingming-Ding, https://github.com/laithsakka
bobrenjc93 2025-03-27 17:53:11 -07:00 committed by PyTorch MergeBot
parent c49315e645
commit f649ee73ce
23 changed files with 521 additions and 443 deletions

View File

@@ -196,6 +196,46 @@ class AOTAutogradCacheTests(InductorTestCase):
         self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 1)
         self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 1)
 
+    @inductor_config.patch("fx_graph_remote_cache", False)
+    @inductor_config.patch("fx_graph_cache", True)
+    @functorch_config.patch({"enable_autograd_cache": True})
+    def test_symbol_specialization(self):
+        """
+        Verify the symbol specializations don't cause cache miss.
+        """
+
+        def fn(x, y, z):
+            return (torch.randn(5) + x + y, z * torch.randn(1))
+
+        a = torch.rand(5)
+        torch._dynamo.maybe_mark_dynamic(a, 0)
+        b = torch.rand(5)
+        c = torch.randn(6)
+        torch._dynamo.maybe_mark_dynamic(c, 0)
+        compiled_fn = torch.compile(fn, backend="inductor")
+
+        # A first call should miss in the cache.
+        compiled_fn(a, b, c)
+        self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1)
+        self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 0)
+        self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 1)
+
+        # A second call should hit even if a new dimension is marked as dynamic
+        # that is later specialized as part of tracing.
+        a = torch.rand(5)
+        torch._dynamo.maybe_mark_dynamic(a, 0)
+        b = torch.rand(5)
+        torch._dynamo.maybe_mark_dynamic(b, 0)
+        c = torch.randn(6)
+        torch._dynamo.maybe_mark_dynamic(c, 0)
+        self._clear_dynamo_and_codecache()
+        compiled_fn(a, b, c)
+
+        self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1)
+        self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 1)
+        self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 1)
+
     @functorch_config.patch({"enable_autograd_cache": True})
     def test_aot_runtime_trace_joint(self):
         @torch.compile(backend="inductor")

View File

@@ -245,7 +245,7 @@ class GraphModule(torch.nn.Module):
             actual,
             """\
 class GraphModule(torch.nn.Module):
-    def forward(self, L_inputs_ : list, L_sizes_0_: "Sym(2)", L_hooks_1_keywords_fn_keywords_obj_counter: "Sym(s7)"):
+    def forward(self, L_inputs_ : list, L_sizes_0_: "Sym(2)", L_hooks_1_keywords_fn_keywords_obj_counter: "Sym(s45)"):
         l_inputs_ = L_inputs_
         l_sizes_0_ = L_sizes_0_
         l_hooks_1_keywords_fn_keywords_obj_counter = L_hooks_1_keywords_fn_keywords_obj_counter
@@ -264,7 +264,7 @@ class GraphModule(torch.nn.Module):
         copy_: "f32[2]" = new_grad_strided.copy_(aot0_tangents_1); copy_ = None
 
-        add: "Sym(s7 + 1)" = l_hooks_1_keywords_fn_keywords_obj_counter + 1; l_hooks_1_keywords_fn_keywords_obj_counter = None
+        add: "Sym(s45 + 1)" = l_hooks_1_keywords_fn_keywords_obj_counter + 1; l_hooks_1_keywords_fn_keywords_obj_counter = None
 
         result: "f32[2]" = aot0_tangents_1 * aot0_tangents_1; aot0_tangents_1 = None

View File

@@ -57,18 +57,18 @@ class ComptimeTests(torch._dynamo.test_case.TestCase):
         self.assertExpectedInline(
             FILE.getvalue().strip(),
             """\
-FakeTensor(..., size=(s0,))
+FakeTensor(..., size=(s77,))
 2
-[FakeTensor(..., size=(s0,)), 2]
-(FakeTensor(..., size=(s0,)), 2)
-{'foo': FakeTensor(..., size=(s0,))}
+[FakeTensor(..., size=(s77,)), 2]
+(FakeTensor(..., size=(s77,)), 2)
+{'foo': FakeTensor(..., size=(s77,))}
 range(1, 3, 1)
 Employee(name='foo', id=2)
 UserDefinedListVariable(mylist)
 defaultdict(NestedUserFunctionVariable(), {})
 set()
 {'a','b'}
-s0""",
+s77""",
         )
 
     def test_print_graph(self):

View File

@ -256,34 +256,34 @@ Model:
==> L['x'].size()[0]: 3 ==> L['x'].size()[0]: 3
==> L['x'].storage_offset(): 0 ==> L['x'].storage_offset(): 0
==> L['x'].stride()[0]: 1 ==> L['x'].stride()[0]: 1
==> s0: 3
==> s1: 0
==> s2: 1
==> s3: 1 ==> s3: 1
==> s52: 1
==> s77: 3
==> s86: 0
Assertions: Assertions:
==> (== 0 L['x'].storage_offset()) ==> (== 0 L['x'].storage_offset())
==> (== 1 L['x'].stride()[0]) ==> (== 1 L['x'].stride()[0])
==> (== L['shape'][0] s1) ==> (== L['shape'][0] s86)
==> (== L['shape'][1] s2) ==> (== L['shape'][1] s52)
==> (== L['shape'][2] s3) ==> (== L['shape'][2] s3)
==> (== L['x'].size()[0] s0) ==> (== L['x'].size()[0] s77)
==> (> s0 1) ==> (> s77 1)
Target Expressions: Target Expressions:
==> (!= (+ s1 s2 s3) s0) ==> (!= (+ s3 s52 s86) s77)
==> (<= 0 s1)
==> (<= 0 s2)
==> (<= 0 s3) ==> (<= 0 s3)
==> (<= 2 s0) ==> (<= 0 s52)
==> (<= 0 s86)
==> (<= 2 s77)
==> (== 0 L['x'].storage_offset()) ==> (== 0 L['x'].storage_offset())
==> (== 1 L['x'].stride()[0]) ==> (== 1 L['x'].stride()[0])
==> (== L['shape'][0] s1) ==> (== L['shape'][0] s86)
==> (== L['shape'][1] s2) ==> (== L['shape'][1] s52)
==> (== L['shape'][2] s3) ==> (== L['shape'][2] s3)
==> (== L['x'].size()[0] s0) ==> (== L['x'].size()[0] s77)
==> (> s0 0) ==> (> s77 0)
==> (>= 0 s1) ==> (>= 0 s86)
Failed Source Expressions: Failed Source Expressions:
==> (== (+ L['shape'][0] L['shape'][1] L['shape'][2]) L['x'].size()[0])""", ==> (== (+ L['shape'][0] L['shape'][1] L['shape'][2]) L['x'].size()[0])""",
@ -309,7 +309,7 @@ Failed Source Expressions:
BisectValidationException, BisectValidationException,
lambda: fn(torch.randn(20), (5, 10, 5)), lambda: fn(torch.randn(20), (5, 10, 5)),
"""\ """\
translation validation failed when evaluating: Eq(s1 + s2 + s3, s0) translation validation failed when evaluating: Eq(s3 + s52 + s86, s77)
Failure occurred while running node: Failure occurred while running node:
%split : [num_users=3] = call_method[target=split](args = (%l_x_, (%l_shape_0_, %l_shape_1_, %l_shape_2_)), kwargs = {}) %split : [num_users=3] = call_method[target=split](args = (%l_x_, (%l_shape_0_, %l_shape_1_, %l_shape_2_)), kwargs = {})
@ -321,33 +321,33 @@ Model:
==> L['x'].size()[0]: 3 ==> L['x'].size()[0]: 3
==> L['x'].storage_offset(): 0 ==> L['x'].storage_offset(): 0
==> L['x'].stride()[0]: 1 ==> L['x'].stride()[0]: 1
==> s0: 3
==> s1: 1
==> s2: 1
==> s3: 0 ==> s3: 0
==> s52: 1
==> s77: 3
==> s86: 1
Assertions: Assertions:
==> (== 0 L['x'].storage_offset()) ==> (== 0 L['x'].storage_offset())
==> (== 1 L['x'].stride()[0]) ==> (== 1 L['x'].stride()[0])
==> (== L['shape'][0] s1) ==> (== L['shape'][0] s86)
==> (== L['shape'][1] s2) ==> (== L['shape'][1] s52)
==> (== L['shape'][2] s3) ==> (== L['shape'][2] s3)
==> (== L['x'].size()[0] s0) ==> (== L['x'].size()[0] s77)
==> (> s0 1) ==> (> s77 1)
Target Expressions: Target Expressions:
==> (!= (+ s1 s2 s3) s0) ==> (!= (+ s3 s52 s86) s77)
==> (<= 0 s1)
==> (<= 0 s2)
==> (<= 0 s3) ==> (<= 0 s3)
==> (<= 2 s0) ==> (<= 0 s52)
==> (<= 0 s86)
==> (<= 2 s77)
==> (== 0 L['x'].storage_offset()) ==> (== 0 L['x'].storage_offset())
==> (== 1 L['x'].stride()[0]) ==> (== 1 L['x'].stride()[0])
==> (== L['shape'][0] s1) ==> (== L['shape'][0] s86)
==> (== L['shape'][1] s2) ==> (== L['shape'][1] s52)
==> (== L['shape'][2] s3) ==> (== L['shape'][2] s3)
==> (== L['x'].size()[0] s0) ==> (== L['x'].size()[0] s77)
==> (> s0 0) ==> (> s77 0)
Failed Source Expressions: Failed Source Expressions:
==> (== (+ L['shape'][0] L['shape'][1] L['shape'][2]) L['x'].size()[0])""", ==> (== (+ L['shape'][0] L['shape'][1] L['shape'][2]) L['x'].size()[0])""",

View File

@@ -2703,7 +2703,7 @@ def forward(self, x):
                 for node in ebar.graph_module.graph.nodes
                 if node.op == "placeholder"
             ],
-            ["torch.Size([s0, s1, s1])", "torch.Size([s0, s1, s1])"],
+            ["torch.Size([s17, s27, s27])", "torch.Size([s17, s27, s27])"],
         )
 
     @torch._dynamo.config.patch(
@@ -3480,23 +3480,23 @@ def forward(self, x):
         true_graph = """\
 class GraphModule(torch.nn.Module):
     def forward(self, pred, x):
-        arg1: "f32[s1, s2]";
+        arg1: "f32[s77, s27]";
         arg0, arg1, = fx_pytree.tree_flatten_spec(([pred, x], {}), self._in_spec)
         l_x_ = arg1
 
-        sin: "f32[s1, s2]" = l_x_.sin(); l_x_ = None
+        sin: "f32[s77, s27]" = l_x_.sin(); l_x_ = None
         return pytree.tree_unflatten([sin], self._out_spec)
 """
         false_graph = """\
 class GraphModule(torch.nn.Module):
     def forward(self, pred, x):
-        arg1: "f32[s1, s2]";
+        arg1: "f32[s77, s27]";
         arg0, arg1, = fx_pytree.tree_flatten_spec(([pred, x], {}), self._in_spec)
         l_x_ = arg1
 
-        cos: "f32[s1, s2]" = l_x_.cos(); l_x_ = None
+        cos: "f32[s77, s27]" = l_x_.cos(); l_x_ = None
         return pytree.tree_unflatten([cos], self._out_spec)
 """
         true_guard_code = [

View File

@ -2655,7 +2655,7 @@ class GraphModule(torch.nn.Module):
normalize_gm(backend.graphs[0].print_readable(print_output=False)), normalize_gm(backend.graphs[0].print_readable(print_output=False)),
"""\ """\
class GraphModule(torch.nn.Module): class GraphModule(torch.nn.Module):
def forward(self, s0: "Sym(s0)", L_x_: "f32[s0]"): def forward(self, s77: "Sym(s77)", L_x_: "f32[s77]"):
l_x_ = L_x_ l_x_ = L_x_
sum_1: "f32[]" = l_x_.sum(); l_x_ = None sum_1: "f32[]" = l_x_.sum(); l_x_ = None
@ -2885,13 +2885,13 @@ class GraphModule(torch.nn.Module):
normalize_gm(backend.graphs[0].print_readable(print_output=False)), normalize_gm(backend.graphs[0].print_readable(print_output=False)),
"""\ """\
class GraphModule(torch.nn.Module): class GraphModule(torch.nn.Module):
def forward(self, s0: "Sym(s0)", L_lambda0_keywords_y_: "f32[s0, s0]"): def forward(self, s9: "Sym(s9)", L_lambda0_keywords_y_: "f32[s9, s9]"):
l_lambda0_keywords_y_ = L_lambda0_keywords_y_ l_lambda0_keywords_y_ = L_lambda0_keywords_y_
mul: "f32[s0, s0]" = l_lambda0_keywords_y_ * l_lambda0_keywords_y_ mul: "f32[s9, s9]" = l_lambda0_keywords_y_ * l_lambda0_keywords_y_
mul_1: "f32[s0, s0]" = l_lambda0_keywords_y_ * l_lambda0_keywords_y_; l_lambda0_keywords_y_ = None mul_1: "f32[s9, s9]" = l_lambda0_keywords_y_ * l_lambda0_keywords_y_; l_lambda0_keywords_y_ = None
mul_2: "f32[s0, s0]" = torch.mul(mul, mul_1); mul = mul_1 = None mul_2: "f32[s9, s9]" = torch.mul(mul, mul_1); mul = mul_1 = None
return (mul_2,) return (mul_2,)
""", """,
) )
@ -2932,14 +2932,14 @@ class GraphModule(torch.nn.Module):
normalize_gm(backend.graphs[0].print_readable(print_output=False)), normalize_gm(backend.graphs[0].print_readable(print_output=False)),
"""\ """\
class GraphModule(torch.nn.Module): class GraphModule(torch.nn.Module):
def forward(self, s0: "Sym(s0)", L_lambda0_keywords_y_: "f32[s0, s0]"): def forward(self, s9: "Sym(s9)", L_lambda0_keywords_y_: "f32[s9, s9]"):
l_lambda0_keywords_y_ = L_lambda0_keywords_y_ l_lambda0_keywords_y_ = L_lambda0_keywords_y_
mul: "f32[s0, s0]" = l_lambda0_keywords_y_ * l_lambda0_keywords_y_ mul: "f32[s9, s9]" = l_lambda0_keywords_y_ * l_lambda0_keywords_y_
add: "f32[s0, s0]" = l_lambda0_keywords_y_ + l_lambda0_keywords_y_; l_lambda0_keywords_y_ = None add: "f32[s9, s9]" = l_lambda0_keywords_y_ + l_lambda0_keywords_y_; l_lambda0_keywords_y_ = None
mul_1: "f32[s0, s0]" = torch.mul(mul, add); mul = add = None mul_1: "f32[s9, s9]" = torch.mul(mul, add); mul = add = None
return (mul_1,) return (mul_1,)
""", """,
) )
@ -2982,14 +2982,14 @@ class GraphModule(torch.nn.Module):
normalize_gm(backend.graphs[0].print_readable(print_output=False)), normalize_gm(backend.graphs[0].print_readable(print_output=False)),
"""\ """\
class GraphModule(torch.nn.Module): class GraphModule(torch.nn.Module):
def forward(self, s0: "Sym(s0)", L_lambda0_keywords_y_: "f32[s0, s0]"): def forward(self, s9: "Sym(s9)", L_lambda0_keywords_y_: "f32[s9, s9]"):
l_lambda0_keywords_y_ = L_lambda0_keywords_y_ l_lambda0_keywords_y_ = L_lambda0_keywords_y_
mul: "f32[s0, s0]" = l_lambda0_keywords_y_ * l_lambda0_keywords_y_ mul: "f32[s9, s9]" = l_lambda0_keywords_y_ * l_lambda0_keywords_y_
add: "f32[s0, s0]" = l_lambda0_keywords_y_ + l_lambda0_keywords_y_; l_lambda0_keywords_y_ = None add: "f32[s9, s9]" = l_lambda0_keywords_y_ + l_lambda0_keywords_y_; l_lambda0_keywords_y_ = None
mul_1: "f32[s0, s0]" = torch.mul(mul, add); mul = add = None mul_1: "f32[s9, s9]" = torch.mul(mul, add); mul = add = None
return (mul_1,) return (mul_1,)
""", """,
) )
@ -3029,14 +3029,14 @@ class GraphModule(torch.nn.Module):
normalize_gm(backend.graphs[0].print_readable(print_output=False)), normalize_gm(backend.graphs[0].print_readable(print_output=False)),
"""\ """\
class GraphModule(torch.nn.Module): class GraphModule(torch.nn.Module):
def forward(self, s0: "Sym(s0)", L_x_: "f32[s0, s0]"): def forward(self, s77: "Sym(s77)", L_x_: "f32[s77, s77]"):
l_x_ = L_x_ l_x_ = L_x_
mul: "f32[s0, s0]" = l_x_ * 4 mul: "f32[s77, s77]" = l_x_ * 4
mul_1: "f32[s0, s0]" = mul * l_x_; mul = None mul_1: "f32[s77, s77]" = mul * l_x_; mul = None
mul_2: "f32[s0, s0]" = 20 * l_x_; l_x_ = None mul_2: "f32[s77, s77]" = 20 * l_x_; l_x_ = None
mul_3: "f32[s0, s0]" = torch.mul(mul_1, mul_2); mul_1 = mul_2 = None mul_3: "f32[s77, s77]" = torch.mul(mul_1, mul_2); mul_1 = mul_2 = None
return (mul_3,) return (mul_3,)
""", """,
) )

View File

@ -413,18 +413,18 @@ class GraphModule(torch.nn.Module):
actual_graph, actual_graph,
"""\ """\
class GraphModule(torch.nn.Module): class GraphModule(torch.nn.Module):
def forward(self, s0: "Sym(s0)", L_x_: "f32[s0, 1]"): def forward(self, s77: "Sym(s77)", L_x_: "f32[s77, 1]"):
l_x_ = L_x_ l_x_ = L_x_
wrap_body_0 = self.wrap_body_0 wrap_body_0 = self.wrap_body_0
wrap = torch.ops.higher_order.wrap(wrap_body_0, s0, l_x_); wrap_body_0 = s0 = l_x_ = None wrap = torch.ops.higher_order.wrap(wrap_body_0, s77, l_x_); wrap_body_0 = s77 = l_x_ = None
getitem: "f32[s0]" = wrap[0]; wrap = None getitem: "f32[s77]" = wrap[0]; wrap = None
return (getitem,) return (getitem,)
class wrap_body_0(torch.nn.Module): class wrap_body_0(torch.nn.Module):
def forward(self, s0: "Sym(s0)", l_x_: "f32[s0, 1]"): def forward(self, s77: "Sym(s77)", l_x_: "f32[s77, 1]"):
view: "f32[s0]" = l_x_.view(s0); l_x_ = s0 = None view: "f32[s77]" = l_x_.view(s77); l_x_ = s77 = None
add: "f32[s0]" = view + 0.5; view = None add: "f32[s77]" = view + 0.5; view = None
return (add,) return (add,)
""", """,
) )
@ -606,27 +606,27 @@ class GraphModule(torch.nn.Module):
out_graph, out_graph,
"""\ """\
class GraphModule(torch.nn.Module): class GraphModule(torch.nn.Module):
def forward(self, s0: "Sym(s0)", L_x_: "f32[s0]"): def forward(self, s77: "Sym(s77)", L_x_: "f32[s77]"):
l_x_ = L_x_ l_x_ = L_x_
sum_1: "f32[]" = l_x_.sum() sum_1: "f32[]" = l_x_.sum()
item: "Sym(zuf0)" = sum_1.item(); sum_1 = None item: "Sym(zuf0)" = sum_1.item(); sum_1 = None
wrap_body_1 = self.wrap_body_1 wrap_body_1 = self.wrap_body_1
wrap = torch.ops.higher_order.wrap(wrap_body_1, s0, l_x_, item); wrap_body_1 = s0 = l_x_ = item = None wrap = torch.ops.higher_order.wrap(wrap_body_1, s77, l_x_, item); wrap_body_1 = s77 = l_x_ = item = None
getitem: "f32[s0]" = wrap[0]; wrap = None getitem: "f32[s77]" = wrap[0]; wrap = None
return (getitem,) return (getitem,)
class wrap_body_1(torch.nn.Module): class wrap_body_1(torch.nn.Module):
def forward(self, s0: "Sym(s0)", l_x_: "f32[s0]", item: "Sym(zuf0)"): def forward(self, s77: "Sym(s77)", l_x_: "f32[s77]", item: "Sym(zuf0)"):
wrap_body_0 = self.wrap_body_0 wrap_body_0 = self.wrap_body_0
wrap = torch.ops.higher_order.wrap(wrap_body_0, s0, l_x_, item); wrap_body_0 = s0 = l_x_ = item = None wrap = torch.ops.higher_order.wrap(wrap_body_0, s77, l_x_, item); wrap_body_0 = s77 = l_x_ = item = None
getitem: "f32[s0]" = wrap[0]; wrap = None getitem: "f32[s77]" = wrap[0]; wrap = None
return (getitem,) return (getitem,)
class wrap_body_0(torch.nn.Module): class wrap_body_0(torch.nn.Module):
def forward(self, s0: "Sym(s0)", l_x_: "f32[s0]", item: "Sym(zuf0)"): def forward(self, s77: "Sym(s77)", l_x_: "f32[s77]", item: "Sym(zuf0)"):
add: "f32[s0]" = l_x_ + item; l_x_ = item = None add: "f32[s77]" = l_x_ + item; l_x_ = item = None
return (add,) return (add,)
""", """,
) )
@ -692,7 +692,7 @@ class GraphModule(torch.nn.Module):
out_graph, out_graph,
"""\ """\
class GraphModule(torch.nn.Module): class GraphModule(torch.nn.Module):
def forward(self, s0: "Sym(s0)", L_x_: "f32[s0]"): def forward(self, s77: "Sym(s77)", L_x_: "f32[s77]"):
l_x_ = L_x_ l_x_ = L_x_
c: "i64[u0, 1]" = l_x_.nonzero() c: "i64[u0, 1]" = l_x_.nonzero()
@ -704,22 +704,22 @@ class GraphModule(torch.nn.Module):
_assert_scalar_default = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge = _assert_scalar_default = None _assert_scalar_default = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge = _assert_scalar_default = None
wrap_body_1 = self.wrap_body_1 wrap_body_1 = self.wrap_body_1
wrap = torch.ops.higher_order.wrap(wrap_body_1, s0, l_x_, sym_size_int_1, c); wrap_body_1 = s0 = l_x_ = sym_size_int_1 = c = None wrap = torch.ops.higher_order.wrap(wrap_body_1, s77, l_x_, sym_size_int_1, c); wrap_body_1 = s77 = l_x_ = sym_size_int_1 = c = None
getitem: "f32[s0]" = wrap[0] getitem: "f32[s77]" = wrap[0]
getitem_1: "f32[u0, 1]" = wrap[1]; wrap = None getitem_1: "f32[u0, 1]" = wrap[1]; wrap = None
return (getitem, getitem_1) return (getitem, getitem_1)
class wrap_body_1(torch.nn.Module): class wrap_body_1(torch.nn.Module):
def forward(self, s0: "Sym(s0)", l_x_: "f32[s0]", u0: "Sym(u0)", c: "i64[u0, 1]"): def forward(self, s77: "Sym(s77)", l_x_: "f32[s77]", u0: "Sym(u0)", c: "i64[u0, 1]"):
wrap_body_0 = self.wrap_body_0 wrap_body_0 = self.wrap_body_0
wrap = torch.ops.higher_order.wrap(wrap_body_0, s0, l_x_, u0, c); wrap_body_0 = s0 = l_x_ = u0 = c = None wrap = torch.ops.higher_order.wrap(wrap_body_0, s77, l_x_, u0, c); wrap_body_0 = s77 = l_x_ = u0 = c = None
child: "f32[s0]" = wrap[0] child: "f32[s77]" = wrap[0]
child_1: "f32[u0, 1]" = wrap[1]; wrap = None child_1: "f32[u0, 1]" = wrap[1]; wrap = None
return (child, child_1) return (child, child_1)
class wrap_body_0(torch.nn.Module): class wrap_body_0(torch.nn.Module):
def forward(self, s0: "Sym(s0)", l_x_: "f32[s0]", u0: "Sym(u0)", c: "i64[u0, 1]"): def forward(self, s77: "Sym(s77)", l_x_: "f32[s77]", u0: "Sym(u0)", c: "i64[u0, 1]"):
child: "f32[s0]" = l_x_.sin(); l_x_ = None child: "f32[s77]" = l_x_.sin(); l_x_ = None
child_1: "f32[u0, 1]" = c.sin(); c = None child_1: "f32[u0, 1]" = c.sin(); c = None
return (child, child_1) return (child, child_1)
""", """,
@ -994,25 +994,25 @@ class GraphModule(torch.nn.Module):
out_graph, out_graph,
"""\ """\
class GraphModule(torch.nn.Module): class GraphModule(torch.nn.Module):
def forward(self, s0: "Sym(s0)", s1: "Sym(s1)", L_x_: "f32[s0, s1]", s2: "Sym(s2)", L_y_: "f32[s1, s2]"): def forward(self, s77: "Sym(s77)", s27: "Sym(s27)", L_x_: "f32[s77, s27]", s94: "Sym(s94)", L_y_: "f32[s27, s94]"):
l_x_ = L_x_ l_x_ = L_x_
l_y_ = L_y_ l_y_ = L_y_
wrap_body_1 = self.wrap_body_1 wrap_body_1 = self.wrap_body_1
wrap = torch.ops.higher_order.wrap(wrap_body_1, s0, s1, l_x_, s2, l_y_); wrap_body_1 = s0 = s1 = l_x_ = s2 = l_y_ = None wrap = torch.ops.higher_order.wrap(wrap_body_1, s77, s27, l_x_, s94, l_y_); wrap_body_1 = s77 = s27 = l_x_ = s94 = l_y_ = None
getitem: "f32[s0, s2]" = wrap[0]; wrap = None getitem: "f32[s77, s94]" = wrap[0]; wrap = None
return (getitem,) return (getitem,)
class wrap_body_1(torch.nn.Module): class wrap_body_1(torch.nn.Module):
def forward(self, s0: "Sym(s0)", s1: "Sym(s1)", l_x_: "f32[s0, s1]", s2: "Sym(s2)", l_y_: "f32[s1, s2]"): def forward(self, s77: "Sym(s77)", s27: "Sym(s27)", l_x_: "f32[s77, s27]", s94: "Sym(s94)", l_y_: "f32[s27, s94]"):
wrap_body_0 = self.wrap_body_0 wrap_body_0 = self.wrap_body_0
wrap = torch.ops.higher_order.wrap(wrap_body_0, s0, s1, l_x_, s2, l_y_); wrap_body_0 = s0 = s1 = l_x_ = s2 = l_y_ = None wrap = torch.ops.higher_order.wrap(wrap_body_0, s77, s27, l_x_, s94, l_y_); wrap_body_0 = s77 = s27 = l_x_ = s94 = l_y_ = None
getitem: "f32[s0, s2]" = wrap[0]; wrap = None getitem: "f32[s77, s94]" = wrap[0]; wrap = None
return (getitem,) return (getitem,)
class wrap_body_0(torch.nn.Module): class wrap_body_0(torch.nn.Module):
def forward(self, s0: "Sym(s0)", s1: "Sym(s1)", l_x_: "f32[s0, s1]", s2: "Sym(s2)", l_y_: "f32[s1, s2]"): def forward(self, s77: "Sym(s77)", s27: "Sym(s27)", l_x_: "f32[s77, s27]", s94: "Sym(s94)", l_y_: "f32[s27, s94]"):
matmul: "f32[s0, s2]" = l_x_ @ l_y_; l_x_ = l_y_ = None matmul: "f32[s77, s94]" = l_x_ @ l_y_; l_x_ = l_y_ = None
return (matmul,) return (matmul,)
""", """,
) )

View File

@ -10382,19 +10382,22 @@ ShapeEnv not equal: field values don't match:
> Left: {x.size()[0]: x.size()[0], x.size()[1]: x.size()[1], x.storage_offset(): x.storage_offset(), x.stride()[0]: x.stride()[0], x.stride()[1]: x.stride()[1]} > Left: {x.size()[0]: x.size()[0], x.size()[1]: x.size()[1], x.storage_offset(): x.storage_offset(), x.stride()[0]: x.stride()[0], x.stride()[1]: x.stride()[1]}
> Right: {} > Right: {}
==> source_to_var: values don't match. ==> source_to_var: values don't match.
> Left: {x.size()[0]: s0, x.size()[1]: s1} > Left: {x.size()[0]: s93, x.size()[1]: s44}
> Right: {}
==> unique_ids: values don't match.
> Left: {44, 93}
> Right: {} > Right: {}
==> val_to_var: values don't match. ==> val_to_var: values don't match.
> Left: {0: 0, 1: 1, 2: s1, 3: s0} > Left: {0: 0, 1: 1, 2: s44, 3: s93}
> Right: {0: 0, 1: 1} > Right: {0: 0, 1: 1}
==> var_to_range: values don't match. ==> var_to_range: values don't match.
> Left: {s0: VR[2, int_oo], s1: VR[2, int_oo]} > Left: {s44: VR[2, int_oo], s93: VR[2, int_oo]}
> Right: {} > Right: {}
==> var_to_sources: values don't match. ==> var_to_sources: values don't match.
> Left: {s0: [TensorPropertySource(base=ConstantSource(source_name='x'), prop=<TensorProperty.SIZE: 0>, idx=0)], s1: [TensorPropertySource(base=ConstantSource(source_name='x'), prop=<TensorProperty.SIZE: 0>, idx=1)]} > Left: {s44: [TensorPropertySource(base=ConstantSource(source_name='x'), prop=<TensorProperty.SIZE: 0>, idx=1)], s93: [TensorPropertySource(base=ConstantSource(source_name='x'), prop=<TensorProperty.SIZE: 0>, idx=0)]}
> Right: {} > Right: {}
==> var_to_val: values don't match. ==> var_to_val: values don't match.
> Left: {s0: 3, s1: 2} > Left: {s44: 2, s93: 3}
> Right: {} > Right: {}
""", """,
) )
@ -10453,13 +10456,13 @@ ShapeEnv not equal: field values don't match:
ShapeEnv not equal: field values don't match: ShapeEnv not equal: field values don't match:
==> axioms: values don't match. ==> axioms: values don't match.
> Left: {(Mod(s0, 3)) < 0: False, (Mod(s0, 3)) <= 0: True, 0 < (Mod(s0, 3)): False, 0 <= (Mod(s0, 3)): True, Eq(0, Mod(s0, 3)): True, Eq(Mod(s0, 3), 0): True, Ne(0, Mod(s0, 3)): False, Ne(Mod(s0, 3), 0): False} > Left: {(Mod(s93, 3)) < 0: False, (Mod(s93, 3)) <= 0: True, 0 < (Mod(s93, 3)): False, 0 <= (Mod(s93, 3)): True, Eq(0, Mod(s93, 3)): True, Eq(Mod(s93, 3), 0): True, Ne(0, Mod(s93, 3)): False, Ne(Mod(s93, 3), 0): False}
> Right: {} > Right: {}
==> divisible: values don't match. ==> divisible: values don't match.
> Left: {Mod(s0, 3)} > Left: {Mod(s93, 3)}
> Right: {} > Right: {}
==> guards: values don't match. ==> guards: values don't match.
> Left: [Eq(Mod(s0, 3), 0)] > Left: [Eq(Mod(s93, 3), 0)]
> Right: [] > Right: []
==> name_to_node: values don't match. ==> name_to_node: values don't match.
> Left: {_assert, eq, mod, x_size_0_, x_size_1_, x_storage_offset, x_stride_0_, x_stride_1_} > Left: {_assert, eq, mod, x_size_0_, x_size_1_, x_storage_offset, x_stride_0_, x_stride_1_}
@ -10496,17 +10499,17 @@ ShapeEnv not equal: field values don't match:
> Left: {False: False, True: True} > Left: {False: False, True: True}
> Right: {} > Right: {}
==> guards: values don't match. ==> guards: values don't match.
> Left: [Eq(s0, 3)] > Left: [Eq(s93, 3)]
> Right: [] > Right: []
==> name_to_node: values don't match. ==> name_to_node: values don't match.
> Left: {_assert, eq, x_size_0_, x_size_1_, x_storage_offset, x_stride_0_, x_stride_1_} > Left: {_assert, eq, x_size_0_, x_size_1_, x_storage_offset, x_stride_0_, x_stride_1_}
> Right: {x_size_0_, x_size_1_, x_storage_offset, x_stride_0_, x_stride_1_} > Right: {x_size_0_, x_size_1_, x_storage_offset, x_stride_0_, x_stride_1_}
==> replacements: values don't match. ==> replacements: values don't match.
> Left: {s0: 3} > Left: {s93: 3}
> Right: {} > Right: {}
==> var_to_range: values don't match. ==> var_to_range: values don't match.
> Left: {s0: VR[3, 3], s1: VR[2, int_oo]} > Left: {s44: VR[2, int_oo], s93: VR[3, 3]}
> Right: {s0: VR[2, int_oo], s1: VR[2, int_oo]} > Right: {s44: VR[2, int_oo], s93: VR[2, int_oo]}
""", """,
) )
self._replay_and_check(main) self._replay_and_check(main)
@ -10537,17 +10540,17 @@ ShapeEnv not equal: field values don't match:
ShapeEnv not equal: field values don't match: ShapeEnv not equal: field values don't match:
==> axioms: values don't match. ==> axioms: values don't match.
> Left: {3 <= s0: True, s0 < 3: False} > Left: {3 <= s93: True, s93 < 3: False}
> Right: {} > Right: {}
==> guards: values don't match. ==> guards: values don't match.
> Left: [s0 >= 3] > Left: [s93 >= 3]
> Right: [] > Right: []
==> name_to_node: values don't match. ==> name_to_node: values don't match.
> Left: {_assert, ge, x_size_0_, x_size_1_, x_storage_offset, x_stride_0_, x_stride_1_} > Left: {_assert, ge, x_size_0_, x_size_1_, x_storage_offset, x_stride_0_, x_stride_1_}
> Right: {x_size_0_, x_size_1_, x_storage_offset, x_stride_0_, x_stride_1_} > Right: {x_size_0_, x_size_1_, x_storage_offset, x_stride_0_, x_stride_1_}
==> var_to_range: values don't match. ==> var_to_range: values don't match.
> Left: {s0: VR[3, int_oo], s1: VR[2, int_oo]} > Left: {s44: VR[2, int_oo], s93: VR[3, int_oo]}
> Right: {s0: VR[2, int_oo], s1: VR[2, int_oo]} > Right: {s44: VR[2, int_oo], s93: VR[2, int_oo]}
""", """,
) )
self._replay_and_check(main) self._replay_and_check(main)

View File

@@ -4768,7 +4768,7 @@ class ReproTests(torch._dynamo.test_case.TestCase):
         self.assertExpectedInline(
             str(graph.code).strip(),
             """\
-def forward(self, s0 : torch.SymInt, s1 : torch.SymInt, L_x_ : torch.Tensor):
+def forward(self, s77 : torch.SymInt, s27 : torch.SymInt, L_x_ : torch.Tensor):
     l_x_ = L_x_
     getitem_2 = l_x_[0]
     sum_1 = getitem_2.sum(); getitem_2 = None

View File

@@ -339,7 +339,7 @@ class StructuredTraceTest(TestCase):
 {"describe_storage": {"id": 0, "describer_id": "ID", "size": 4000000}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0}
 {"describe_tensor": {"id": 0, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cpu')", "size": [1000, 1000], "is_leaf": true, "stride": [1000, 1], "storage": 0, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0}
 {"describe_source": {"describer_id": "ID", "id": 0, "source": "L['x']"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0}
-{"create_symbol": {"symbol": "s0", "val": "1", "vr": "[-int_oo, int_oo]", "source": "L['y']", "user_stack": "STACK", "stack": "STACK"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0}
+{"create_symbol": {"symbol": "s48", "val": "1", "vr": "[-int_oo, int_oo]", "source": "L['y']", "user_stack": "STACK", "stack": "STACK"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0}
 {"dynamo_output_graph": {"sizes": {"l_x_": [1000, 1000], "add": [1000, 1000]}}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"}
 {"inductor_pre_grad_graph": {}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"}
 {"artifact": {"name": "before_recompile_pre_grad", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"}
@@ -767,15 +767,15 @@ class StructuredTraceTest(TestCase):
 {"describe_storage": {"id": 0, "describer_id": "ID", "size": 200}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0}
 {"describe_tensor": {"id": 0, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cpu')", "size": [5, 10], "is_leaf": true, "stride": [10, 1], "storage": 0, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0}
 {"describe_source": {"describer_id": "ID", "id": 0, "source": "L['a']"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0}
-{"create_symbol": {"symbol": "s0", "val": "5", "vr": "[2, int_oo]", "source": "L['a'].size()[0]", "user_stack": "STACK", "stack": "STACK"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0}
-{"create_symbol": {"symbol": "s1", "val": "10", "vr": "[2, int_oo]", "source": "L['a'].size()[1]", "user_stack": "STACK", "stack": "STACK"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0}
+{"create_symbol": {"symbol": "s97", "val": "5", "vr": "[2, int_oo]", "source": "L['a'].size()[0]", "user_stack": "STACK", "stack": "STACK"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0}
+{"create_symbol": {"symbol": "s98", "val": "10", "vr": "[2, int_oo]", "source": "L['a'].size()[1]", "user_stack": "STACK", "stack": "STACK"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0}
 {"describe_storage": {"id": 1, "describer_id": "ID", "size": 600}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0}
 {"describe_tensor": {"id": 1, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cpu')", "size": [10, 15], "is_leaf": true, "stride": [15, 1], "storage": 1, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0}
 {"describe_source": {"describer_id": "ID", "id": 1, "source": "L['b']"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0}
-{"create_symbol": {"symbol": "s2", "val": "10", "vr": "[2, int_oo]", "source": "L['b'].size()[0]", "user_stack": "STACK", "stack": "STACK"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0}
-{"create_symbol": {"symbol": "s3", "val": "15", "vr": "[2, int_oo]", "source": "L['b'].size()[1]", "user_stack": "STACK", "stack": "STACK"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0}
-{"guard_added_fast": {"expr": "Eq(s1, s2)", "user_stack": "STACK", "stack": "STACK"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0}
-{"dynamo_output_graph": {"sizes": {"l_a_": ["s0", "s1"], "l_b_": ["s1", "s3"], "matmul": ["s0", "s3"]}}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"}
+{"create_symbol": {"symbol": "s52", "val": "10", "vr": "[2, int_oo]", "source": "L['b'].size()[0]", "user_stack": "STACK", "stack": "STACK"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0}
+{"create_symbol": {"symbol": "s20", "val": "15", "vr": "[2, int_oo]", "source": "L['b'].size()[1]", "user_stack": "STACK", "stack": "STACK"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0}
+{"guard_added_fast": {"expr": "Eq(s98, s52)", "user_stack": "STACK", "stack": "STACK"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0}
+{"dynamo_output_graph": {"sizes": {"l_a_": ["s97", "s52"], "l_b_": ["s52", "s20"], "matmul": ["s97", "s20"]}}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"}
 {"dynamo_cpp_guards_str": {}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"}
 {"compilation_metrics": "METRICS", "frame_id": 0, "frame_compile_id": 1, "attempt": 0}
 """, # noqa: B950

View File

@ -1373,21 +1373,21 @@ class GraphModule(torch.nn.Module):
# During fakeifying, we end up allocating a separate symint # During fakeifying, we end up allocating a separate symint
# for the outer and inner tensor (in this test, s0 is unused). # for the outer and inner tensor (in this test, s0 is unused).
expected_var_to_val = { expected_var_to_val = {
"s0": 8, "s50": 4,
"s1": 4, "s77": 8,
} }
expected_var_to_sources = { expected_var_to_sources = {
"s0": "L['x'].size()[0]", "s50": "L['x'].inner_elem.size()[0]",
"s1": "L['x'].inner_elem.size()[0]", "s77": "L['x'].size()[0]",
} }
self.assertEqual(curr_var_to_val, expected_var_to_val) self.assertEqual(curr_var_to_val, expected_var_to_val)
self.assertEqual(curr_var_to_sources, expected_var_to_sources) self.assertEqual(curr_var_to_sources, expected_var_to_sources)
self.assertExpectedInline( self.assertExpectedInline(
"\n".join(guards), "\n".join(guards),
"""\ """\
Eq(2*s1, s0) Eq(2*s50, s77)
2*s1 < 13 2*s50 < 13
s1 > 3""", s50 > 3""",
) )
def test_wrapper_subclass_with_same_sized_inner_tensor(self): def test_wrapper_subclass_with_same_sized_inner_tensor(self):
@ -1976,9 +1976,9 @@ class GraphModule(torch.nn.Module):
normalize_gm(fw[0].print_readable(print_output=False)), normalize_gm(fw[0].print_readable(print_output=False)),
"""\ """\
class GraphModule(torch.nn.Module): class GraphModule(torch.nn.Module):
def forward(self, primals_1: "Sym(s0)", primals_2: "Sym(s1)", primals_3: "f32[s0, s1]", primals_4: "f32[s0, s1]", primals_5: "Sym(s0)", primals_6: "Sym(s1)", primals_7: "Sym(s1)"): def forward(self, primals_1: "Sym(s47)", primals_2: "Sym(s16)", primals_3: "f32[s47, s16]", primals_4: "f32[s47, s16]", primals_5: "Sym(s47)", primals_6: "Sym(s16)", primals_7: "Sym(s16)"):
mul: "f32[s0, s1]" = torch.ops.aten.mul.Tensor(primals_3, primals_1); primals_3 = None mul: "f32[s47, s16]" = torch.ops.aten.mul.Tensor(primals_3, primals_1); primals_3 = None
mul_3: "f32[s0, s1]" = torch.ops.aten.mul.Tensor(primals_4, primals_1); primals_4 = None mul_3: "f32[s47, s16]" = torch.ops.aten.mul.Tensor(primals_4, primals_1); primals_4 = None
return (mul, mul_3, primals_5, primals_7, primals_7, primals_1, primals_5, primals_7) return (mul, mul_3, primals_5, primals_7, primals_7, primals_1, primals_5, primals_7)
""", # noqa: B950 """, # noqa: B950
) )
@ -1987,9 +1987,9 @@ class GraphModule(torch.nn.Module):
normalize_gm(bw[0].print_readable(print_output=False)), normalize_gm(bw[0].print_readable(print_output=False)),
"""\ """\
class GraphModule(torch.nn.Module): class GraphModule(torch.nn.Module):
def forward(self, primals_1: "Sym(s0)", primals_5: "Sym(s0)", primals_7: "Sym(s1)", tangents_1: "f32[s0, s1]", tangents_2: "f32[s0, s1]"): def forward(self, primals_1: "Sym(s47)", primals_5: "Sym(s47)", primals_7: "Sym(s16)", tangents_1: "f32[s47, s16]", tangents_2: "f32[s47, s16]"):
mul_8: "f32[s0, s1]" = torch.ops.aten.mul.Tensor(tangents_1, primals_1); tangents_1 = None mul_8: "f32[s47, s16]" = torch.ops.aten.mul.Tensor(tangents_1, primals_1); tangents_1 = None
mul_9: "f32[s0, s1]" = torch.ops.aten.mul.Tensor(tangents_2, primals_1); tangents_2 = primals_1 = None mul_9: "f32[s47, s16]" = torch.ops.aten.mul.Tensor(tangents_2, primals_1); tangents_2 = primals_1 = None
return (None, None, mul_8, mul_9, primals_5, primals_7, primals_7) return (None, None, mul_8, mul_9, primals_5, primals_7, primals_7)
""", # noqa: B950 """, # noqa: B950
) )
@ -2009,12 +2009,12 @@ class GraphModule(torch.nn.Module):
normalize_gm(fw[0].print_readable(print_output=False)), normalize_gm(fw[0].print_readable(print_output=False)),
"""\ """\
class GraphModule(torch.nn.Module): class GraphModule(torch.nn.Module):
def forward(self, primals_1: "Sym(s0)", primals_2: "Sym(s1)", primals_3: "f32[s0, s1]", primals_4: "f32[s0, s1]", primals_5: "Sym(s0)", primals_6: "Sym(s1)", primals_7: "Sym(s1)"): def forward(self, primals_1: "Sym(s47)", primals_2: "Sym(s16)", primals_3: "f32[s47, s16]", primals_4: "f32[s47, s16]", primals_5: "Sym(s47)", primals_6: "Sym(s16)", primals_7: "Sym(s16)"):
clone: "f32[s0, s1]" = torch.ops.aten.clone.default(primals_3); primals_3 = None clone: "f32[s47, s16]" = torch.ops.aten.clone.default(primals_3); primals_3 = None
clone_1: "f32[s0, s1]" = torch.ops.aten.clone.default(primals_4); primals_4 = None clone_1: "f32[s47, s16]" = torch.ops.aten.clone.default(primals_4); primals_4 = None
view: "f32[s1, s0]" = torch.ops.aten.view.default(clone, [primals_2, primals_1]); clone = None view: "f32[s16, s47]" = torch.ops.aten.view.default(clone, [primals_2, primals_1]); clone = None
view_1: "f32[s1, s0]" = torch.ops.aten.view.default(clone_1, [primals_2, primals_1]); clone_1 = primals_1 = None view_1: "f32[s16, s47]" = torch.ops.aten.view.default(clone_1, [primals_2, primals_1]); clone_1 = primals_1 = None
return (view, view_1, primals_2, primals_5, primals_5, primals_5, primals_7) return (view, view_1, primals_2, primals_5, primals_5, primals_5, primals_7)
""", # noqa: B950 """, # noqa: B950
) )
@ -2023,9 +2023,9 @@ class GraphModule(torch.nn.Module):
normalize_gm(bw[0].print_readable(print_output=False)), normalize_gm(bw[0].print_readable(print_output=False)),
"""\ """\
class GraphModule(torch.nn.Module): class GraphModule(torch.nn.Module):
def forward(self, primals_5: "Sym(s0)", primals_7: "Sym(s1)", tangents_1: "f32[s1, s0]", tangents_2: "f32[s1, s0]"): def forward(self, primals_5: "Sym(s47)", primals_7: "Sym(s16)", tangents_1: "f32[s16, s47]", tangents_2: "f32[s16, s47]"):
view_2: "f32[s0, s1]" = torch.ops.aten.view.default(tangents_1, [primals_5, primals_7]); tangents_1 = None view_2: "f32[s47, s16]" = torch.ops.aten.view.default(tangents_1, [primals_5, primals_7]); tangents_1 = None
view_3: "f32[s0, s1]" = torch.ops.aten.view.default(tangents_2, [primals_5, primals_7]); tangents_2 = None view_3: "f32[s47, s16]" = torch.ops.aten.view.default(tangents_2, [primals_5, primals_7]); tangents_2 = None
return (None, None, view_2, view_3, primals_5, primals_7, primals_7) return (None, None, view_2, view_3, primals_5, primals_7, primals_7)
""", # noqa: B950 """, # noqa: B950
) )
@ -2047,15 +2047,15 @@ class GraphModule(torch.nn.Module):
normalize_gm(fw[0].print_readable(print_output=False)), normalize_gm(fw[0].print_readable(print_output=False)),
"""\ """\
class GraphModule(torch.nn.Module): class GraphModule(torch.nn.Module):
def forward(self, primals_1: "Sym(s0)", primals_2: "Sym(s1)", primals_3: "f32[s0, s1]", primals_4: "f32[s0, s1]", primals_5: "Sym(s0)", primals_6: "Sym(s1)", primals_7: "Sym(s1)"): def forward(self, primals_1: "Sym(s97)", primals_2: "Sym(s98)", primals_3: "f32[s97, s98]", primals_4: "f32[s97, s98]", primals_5: "Sym(s97)", primals_6: "Sym(s98)", primals_7: "Sym(s98)"):
mul: "f32[s0, s1]" = torch.ops.aten.mul.Tensor(primals_3, primals_1); primals_3 = None mul: "f32[s97, s98]" = torch.ops.aten.mul.Tensor(primals_3, primals_1); primals_3 = None
mul_3: "f32[s0, s1]" = torch.ops.aten.mul.Tensor(primals_4, primals_1); primals_4 = None mul_3: "f32[s97, s98]" = torch.ops.aten.mul.Tensor(primals_4, primals_1); primals_4 = None
mul_8: "f32[s0, s1]" = torch.ops.aten.mul.Tensor(mul, primals_2); mul = None mul_8: "f32[s97, s98]" = torch.ops.aten.mul.Tensor(mul, primals_2); mul = None
mul_11: "f32[s0, s1]" = torch.ops.aten.mul.Tensor(mul_3, primals_2); mul_3 = None mul_11: "f32[s97, s98]" = torch.ops.aten.mul.Tensor(mul_3, primals_2); mul_3 = None
mul_16: "f32[s0, s1]" = torch.ops.aten.mul.Tensor(mul_8, primals_1); mul_8 = None mul_16: "f32[s97, s98]" = torch.ops.aten.mul.Tensor(mul_8, primals_1); mul_8 = None
mul_19: "f32[s0, s1]" = torch.ops.aten.mul.Tensor(mul_11, primals_1); mul_11 = None mul_19: "f32[s97, s98]" = torch.ops.aten.mul.Tensor(mul_11, primals_1); mul_11 = None
mul_24: "f32[s0, s1]" = torch.ops.aten.mul.Tensor(mul_16, primals_2); mul_16 = None mul_24: "f32[s97, s98]" = torch.ops.aten.mul.Tensor(mul_16, primals_2); mul_16 = None
mul_27: "f32[s0, s1]" = torch.ops.aten.mul.Tensor(mul_19, primals_2); mul_19 = None mul_27: "f32[s97, s98]" = torch.ops.aten.mul.Tensor(mul_19, primals_2); mul_19 = None
return (mul_24, mul_27, primals_5, primals_7, primals_7, primals_1, primals_2, primals_5, primals_7) return (mul_24, mul_27, primals_5, primals_7, primals_7, primals_1, primals_2, primals_5, primals_7)
""", # noqa: B950 """, # noqa: B950
) )
@ -2064,15 +2064,15 @@ class GraphModule(torch.nn.Module):
normalize_gm(bw[0].print_readable(print_output=False)), normalize_gm(bw[0].print_readable(print_output=False)),
"""\ """\
class GraphModule(torch.nn.Module): class GraphModule(torch.nn.Module):
def forward(self, primals_1: "Sym(s0)", primals_2: "Sym(s1)", primals_5: "Sym(s0)", primals_7: "Sym(s1)", tangents_1: "f32[s0, s1]", tangents_2: "f32[s0, s1]"): def forward(self, primals_1: "Sym(s97)", primals_2: "Sym(s98)", primals_5: "Sym(s97)", primals_7: "Sym(s98)", tangents_1: "f32[s97, s98]", tangents_2: "f32[s97, s98]"):
mul_32: "f32[s0, s1]" = torch.ops.aten.mul.Tensor(tangents_1, primals_2); tangents_1 = None mul_32: "f32[s97, s98]" = torch.ops.aten.mul.Tensor(tangents_1, primals_2); tangents_1 = None
mul_33: "f32[s0, s1]" = torch.ops.aten.mul.Tensor(tangents_2, primals_2); tangents_2 = None mul_33: "f32[s97, s98]" = torch.ops.aten.mul.Tensor(tangents_2, primals_2); tangents_2 = None
mul_34: "f32[s0, s1]" = torch.ops.aten.mul.Tensor(mul_32, primals_1); mul_32 = None mul_34: "f32[s97, s98]" = torch.ops.aten.mul.Tensor(mul_32, primals_1); mul_32 = None
mul_35: "f32[s0, s1]" = torch.ops.aten.mul.Tensor(mul_33, primals_1); mul_33 = None mul_35: "f32[s97, s98]" = torch.ops.aten.mul.Tensor(mul_33, primals_1); mul_33 = None
mul_36: "f32[s0, s1]" = torch.ops.aten.mul.Tensor(mul_34, primals_2); mul_34 = None mul_36: "f32[s97, s98]" = torch.ops.aten.mul.Tensor(mul_34, primals_2); mul_34 = None
mul_37: "f32[s0, s1]" = torch.ops.aten.mul.Tensor(mul_35, primals_2); mul_35 = primals_2 = None mul_37: "f32[s97, s98]" = torch.ops.aten.mul.Tensor(mul_35, primals_2); mul_35 = primals_2 = None
mul_38: "f32[s0, s1]" = torch.ops.aten.mul.Tensor(mul_36, primals_1); mul_36 = None mul_38: "f32[s97, s98]" = torch.ops.aten.mul.Tensor(mul_36, primals_1); mul_36 = None
mul_39: "f32[s0, s1]" = torch.ops.aten.mul.Tensor(mul_37, primals_1); mul_37 = primals_1 = None mul_39: "f32[s97, s98]" = torch.ops.aten.mul.Tensor(mul_37, primals_1); mul_37 = primals_1 = None
return (None, None, mul_38, mul_39, primals_5, primals_7, primals_7) return (None, None, mul_38, mul_39, primals_5, primals_7, primals_7)
""", # noqa: B950 """, # noqa: B950
) )
@ -2092,12 +2092,12 @@ class GraphModule(torch.nn.Module):
normalize_gm(fw[0].print_readable(print_output=False)), normalize_gm(fw[0].print_readable(print_output=False)),
"""\ """\
class GraphModule(torch.nn.Module): class GraphModule(torch.nn.Module):
def forward(self, primals_1: "Sym(s0)", primals_2: "Sym(s1)", primals_3: "f32[s0, s1]", primals_4: "f32[s0, s1]", primals_5: "Sym(s0)", primals_6: "Sym(s1)", primals_7: "Sym(s1)"): def forward(self, primals_1: "Sym(s47)", primals_2: "Sym(s16)", primals_3: "f32[s47, s16]", primals_4: "f32[s47, s16]", primals_5: "Sym(s47)", primals_6: "Sym(s16)", primals_7: "Sym(s16)"):
clone: "f32[s0, s1]" = torch.ops.aten.clone.default(primals_3); primals_3 = None clone: "f32[s47, s16]" = torch.ops.aten.clone.default(primals_3); primals_3 = None
clone_1: "f32[s0, s1]" = torch.ops.aten.clone.default(primals_4); primals_4 = None clone_1: "f32[s47, s16]" = torch.ops.aten.clone.default(primals_4); primals_4 = None
view: "f32[s0, s1]" = torch.ops.aten.view.default(clone, [primals_1, primals_2]); clone = None view: "f32[s47, s16]" = torch.ops.aten.view.default(clone, [primals_1, primals_2]); clone = None
view_1: "f32[s0, s1]" = torch.ops.aten.view.default(clone_1, [primals_1, primals_2]); clone_1 = primals_1 = primals_2 = None view_1: "f32[s47, s16]" = torch.ops.aten.view.default(clone_1, [primals_1, primals_2]); clone_1 = primals_1 = primals_2 = None
return (view, view_1, primals_5, primals_7, primals_7, primals_5, primals_7) return (view, view_1, primals_5, primals_7, primals_7, primals_5, primals_7)
""", # noqa: B950 """, # noqa: B950
) )
@ -2106,9 +2106,9 @@ class GraphModule(torch.nn.Module):
normalize_gm(bw[0].print_readable(print_output=False)), normalize_gm(bw[0].print_readable(print_output=False)),
"""\ """\
class GraphModule(torch.nn.Module): class GraphModule(torch.nn.Module):
def forward(self, primals_5: "Sym(s0)", primals_7: "Sym(s1)", tangents_1: "f32[s0, s1]", tangents_2: "f32[s0, s1]"): def forward(self, primals_5: "Sym(s47)", primals_7: "Sym(s16)", tangents_1: "f32[s47, s16]", tangents_2: "f32[s47, s16]"):
view_2: "f32[s0, s1]" = torch.ops.aten.view.default(tangents_1, [primals_5, primals_7]); tangents_1 = None view_2: "f32[s47, s16]" = torch.ops.aten.view.default(tangents_1, [primals_5, primals_7]); tangents_1 = None
view_3: "f32[s0, s1]" = torch.ops.aten.view.default(tangents_2, [primals_5, primals_7]); tangents_2 = None view_3: "f32[s47, s16]" = torch.ops.aten.view.default(tangents_2, [primals_5, primals_7]); tangents_2 = None
return (None, None, view_2, view_3, primals_5, primals_7, primals_7) return (None, None, view_2, view_3, primals_5, primals_7, primals_7)
""", # noqa: B950 """, # noqa: B950
) )
@ -2128,13 +2128,13 @@ class GraphModule(torch.nn.Module):
normalize_gm(fw[0].print_readable(print_output=False)), normalize_gm(fw[0].print_readable(print_output=False)),
"""\ """\
class GraphModule(torch.nn.Module): class GraphModule(torch.nn.Module):
def forward(self, primals_1: "Sym(s0)", primals_2: "Sym(s1)", primals_3: "f32[s0, s1]", primals_4: "f32[s0, s1]", primals_5: "Sym(s0)", primals_6: "Sym(s1)", primals_7: "Sym(s1)"): def forward(self, primals_1: "Sym(s47)", primals_2: "Sym(s16)", primals_3: "f32[s47, s16]", primals_4: "f32[s47, s16]", primals_5: "Sym(s47)", primals_6: "Sym(s16)", primals_7: "Sym(s16)"):
clone: "f32[s0, s1]" = torch.ops.aten.clone.default(primals_3); primals_3 = None clone: "f32[s47, s16]" = torch.ops.aten.clone.default(primals_3); primals_3 = None
clone_1: "f32[s0, s1]" = torch.ops.aten.clone.default(primals_4); primals_4 = None clone_1: "f32[s47, s16]" = torch.ops.aten.clone.default(primals_4); primals_4 = None
mul_6: "Sym(s0*s1)" = primals_1 * primals_2; primals_1 = primals_2 = None mul_6: "Sym(s16*s47)" = primals_1 * primals_2; primals_1 = primals_2 = None
view: "f32[s0*s1]" = torch.ops.aten.view.default(clone, [mul_6]); clone = None view: "f32[s16*s47]" = torch.ops.aten.view.default(clone, [mul_6]); clone = None
view_1: "f32[s0*s1]" = torch.ops.aten.view.default(clone_1, [mul_6]); clone_1 = None view_1: "f32[s16*s47]" = torch.ops.aten.view.default(clone_1, [mul_6]); clone_1 = None
return (view, view_1, mul_6, primals_5, primals_7) return (view, view_1, mul_6, primals_5, primals_7)
""", # noqa: B950 """, # noqa: B950
) )
@ -2143,9 +2143,9 @@ class GraphModule(torch.nn.Module):
normalize_gm(bw[0].print_readable(print_output=False)), normalize_gm(bw[0].print_readable(print_output=False)),
"""\ """\
class GraphModule(torch.nn.Module): class GraphModule(torch.nn.Module):
def forward(self, primals_5: "Sym(s0)", primals_7: "Sym(s1)", tangents_1: "f32[s0*s1]", tangents_2: "f32[s0*s1]"): def forward(self, primals_5: "Sym(s47)", primals_7: "Sym(s16)", tangents_1: "f32[s16*s47]", tangents_2: "f32[s16*s47]"):
view_2: "f32[s0, s1]" = torch.ops.aten.view.default(tangents_1, [primals_5, primals_7]); tangents_1 = None view_2: "f32[s47, s16]" = torch.ops.aten.view.default(tangents_1, [primals_5, primals_7]); tangents_1 = None
view_3: "f32[s0, s1]" = torch.ops.aten.view.default(tangents_2, [primals_5, primals_7]); tangents_2 = None view_3: "f32[s47, s16]" = torch.ops.aten.view.default(tangents_2, [primals_5, primals_7]); tangents_2 = None
return (None, None, view_2, view_3, primals_5, primals_7, primals_7) return (None, None, view_2, view_3, primals_5, primals_7, primals_7)
""", # noqa: B950 """, # noqa: B950
) )
@ -2165,13 +2165,13 @@ class GraphModule(torch.nn.Module):
normalize_gm(fw[0].print_readable(print_output=False)), normalize_gm(fw[0].print_readable(print_output=False)),
"""\ """\
class GraphModule(torch.nn.Module): class GraphModule(torch.nn.Module):
def forward(self, primals_1: "Sym(s0)", primals_2: "Sym(s1)", primals_3: "f32[s0, s1]", primals_4: "f32[s0, s1]", primals_5: "Sym(s0)", primals_6: "Sym(s1)", primals_7: "Sym(s1)"): def forward(self, primals_1: "Sym(s47)", primals_2: "Sym(s16)", primals_3: "f32[s47, s16]", primals_4: "f32[s47, s16]", primals_5: "Sym(s47)", primals_6: "Sym(s16)", primals_7: "Sym(s16)"):
clone: "f32[s0, s1]" = torch.ops.aten.clone.default(primals_3); primals_3 = None clone: "f32[s47, s16]" = torch.ops.aten.clone.default(primals_3); primals_3 = None
clone_1: "f32[s0, s1]" = torch.ops.aten.clone.default(primals_4); primals_4 = None clone_1: "f32[s47, s16]" = torch.ops.aten.clone.default(primals_4); primals_4 = None
mul_6: "Sym(s0*s1)" = primals_1 * primals_2; primals_1 = primals_2 = None mul_6: "Sym(s16*s47)" = primals_1 * primals_2; primals_1 = primals_2 = None
view: "f32[s0*s1]" = torch.ops.aten.view.default(clone, [mul_6]) view: "f32[s16*s47]" = torch.ops.aten.view.default(clone, [mul_6])
view_1: "f32[s0*s1]" = torch.ops.aten.view.default(clone_1, [mul_6]); clone_1 = None view_1: "f32[s16*s47]" = torch.ops.aten.view.default(clone_1, [mul_6]); clone_1 = None
return (clone, view, view_1, mul_6, primals_5, primals_7) return (clone, view, view_1, mul_6, primals_5, primals_7)
""", # noqa: B950 """, # noqa: B950
) )
@ -2180,9 +2180,9 @@ class GraphModule(torch.nn.Module):
normalize_gm(bw[0].print_readable(print_output=False)), normalize_gm(bw[0].print_readable(print_output=False)),
"""\ """\
class GraphModule(torch.nn.Module): class GraphModule(torch.nn.Module):
def forward(self, primals_5: "Sym(s0)", primals_7: "Sym(s1)", tangents_1: "f32[s0*s1]", tangents_2: "f32[s0*s1]"): def forward(self, primals_5: "Sym(s47)", primals_7: "Sym(s16)", tangents_1: "f32[s16*s47]", tangents_2: "f32[s16*s47]"):
view_2: "f32[s0, s1]" = torch.ops.aten.view.default(tangents_1, [primals_5, primals_7]); tangents_1 = None view_2: "f32[s47, s16]" = torch.ops.aten.view.default(tangents_1, [primals_5, primals_7]); tangents_1 = None
view_3: "f32[s0, s1]" = torch.ops.aten.view.default(tangents_2, [primals_5, primals_7]); tangents_2 = None view_3: "f32[s47, s16]" = torch.ops.aten.view.default(tangents_2, [primals_5, primals_7]); tangents_2 = None
return (None, None, view_2, view_3, primals_5, primals_7, primals_7) return (None, None, view_2, view_3, primals_5, primals_7, primals_7)
""", # noqa: B950 """, # noqa: B950
) )
@ -2261,13 +2261,13 @@ class GraphModule(torch.nn.Module):
normalize_gm(fw[1].print_readable(print_output=False)), normalize_gm(fw[1].print_readable(print_output=False)),
"""\ """\
class GraphModule(torch.nn.Module): class GraphModule(torch.nn.Module):
def forward(self, primals_1: "Sym(s0)", primals_2: "f32[3, s0]", primals_3: "f32[3, s0]", primals_4: "Sym(s0)", primals_5: "Sym(s0)"): def forward(self, primals_1: "Sym(s16)", primals_2: "f32[3, s16]", primals_3: "f32[3, s16]", primals_4: "Sym(s16)", primals_5: "Sym(s16)"):
clone: "f32[3, s0]" = torch.ops.aten.clone.default(primals_2); primals_2 = None clone: "f32[3, s16]" = torch.ops.aten.clone.default(primals_2); primals_2 = None
clone_1: "f32[3, s0]" = torch.ops.aten.clone.default(primals_3); primals_3 = None clone_1: "f32[3, s16]" = torch.ops.aten.clone.default(primals_3); primals_3 = None
view: "f32[3*s0]" = torch.ops.aten.view.default(clone, [-1]) view: "f32[3*s16]" = torch.ops.aten.view.default(clone, [-1])
sym_numel_default: "Sym(3*s0)" = torch.ops.aten.sym_numel.default(clone) sym_numel_default: "Sym(3*s16)" = torch.ops.aten.sym_numel.default(clone)
view_1: "f32[3*s0]" = torch.ops.aten.view.default(clone_1, [-1]) view_1: "f32[3*s16]" = torch.ops.aten.view.default(clone_1, [-1])
return (clone, view, view_1, sym_numel_default, clone_1, primals_5) return (clone, view, view_1, sym_numel_default, clone_1, primals_5)
""", # noqa: B950 """, # noqa: B950
) )
@ -2287,9 +2287,9 @@ class GraphModule(torch.nn.Module):
normalize_gm(bw[1].print_readable(print_output=False)), normalize_gm(bw[1].print_readable(print_output=False)),
"""\ """\
class GraphModule(torch.nn.Module): class GraphModule(torch.nn.Module):
def forward(self, primals_5: "Sym(s0)", tangents_1: "f32[3*s0]", tangents_2: "f32[3*s0]"): def forward(self, primals_5: "Sym(s16)", tangents_1: "f32[3*s16]", tangents_2: "f32[3*s16]"):
view_2: "f32[3, s0]" = torch.ops.aten.view.default(tangents_1, [3, primals_5]); tangents_1 = None view_2: "f32[3, s16]" = torch.ops.aten.view.default(tangents_1, [3, primals_5]); tangents_1 = None
view_3: "f32[3, s0]" = torch.ops.aten.view.default(tangents_2, [3, primals_5]); tangents_2 = None view_3: "f32[3, s16]" = torch.ops.aten.view.default(tangents_2, [3, primals_5]); tangents_2 = None
return (None, view_2, view_3, primals_5, primals_5) return (None, view_2, view_3, primals_5, primals_5)
""", # noqa: B950 """, # noqa: B950
) )
@ -2317,13 +2317,13 @@ class GraphModule(torch.nn.Module):
normalize_gm(fw[0].print_readable(print_output=False)), normalize_gm(fw[0].print_readable(print_output=False)),
"""\ """\
class GraphModule(torch.nn.Module): class GraphModule(torch.nn.Module):
def forward(self, primals_1: "Sym(s0)", primals_2: "f32[3, s0]", primals_3: "f32[3, s0]", primals_4: "Sym(s0)", primals_5: "Sym(s0)"): def forward(self, primals_1: "Sym(s16)", primals_2: "f32[3, s16]", primals_3: "f32[3, s16]", primals_4: "Sym(s16)", primals_5: "Sym(s16)"):
clone: "f32[3, s0]" = torch.ops.aten.clone.default(primals_2); primals_2 = None clone: "f32[3, s16]" = torch.ops.aten.clone.default(primals_2); primals_2 = None
clone_1: "f32[3, s0]" = torch.ops.aten.clone.default(primals_3); primals_3 = None clone_1: "f32[3, s16]" = torch.ops.aten.clone.default(primals_3); primals_3 = None
view: "f32[3*s0]" = torch.ops.aten.view.default(clone, [-1]) view: "f32[3*s16]" = torch.ops.aten.view.default(clone, [-1])
sym_numel_default: "Sym(3*s0)" = torch.ops.aten.sym_numel.default(clone) sym_numel_default: "Sym(3*s16)" = torch.ops.aten.sym_numel.default(clone)
view_1: "f32[3*s0]" = torch.ops.aten.view.default(clone_1, [-1]) view_1: "f32[3*s16]" = torch.ops.aten.view.default(clone_1, [-1])
return (clone, view, view_1, sym_numel_default, clone_1, primals_5) return (clone, view, view_1, sym_numel_default, clone_1, primals_5)
""", # noqa: B950 """, # noqa: B950
) )
@ -2332,9 +2332,9 @@ class GraphModule(torch.nn.Module):
normalize_gm(bw[0].print_readable(print_output=False)), normalize_gm(bw[0].print_readable(print_output=False)),
"""\ """\
class GraphModule(torch.nn.Module): class GraphModule(torch.nn.Module):
def forward(self, primals_5: "Sym(s0)", tangents_1: "f32[3*s0]", tangents_2: "f32[3*s0]"): def forward(self, primals_5: "Sym(s16)", tangents_1: "f32[3*s16]", tangents_2: "f32[3*s16]"):
view_2: "f32[3, s0]" = torch.ops.aten.view.default(tangents_1, [3, primals_5]); tangents_1 = None view_2: "f32[3, s16]" = torch.ops.aten.view.default(tangents_1, [3, primals_5]); tangents_1 = None
view_3: "f32[3, s0]" = torch.ops.aten.view.default(tangents_2, [3, primals_5]); tangents_2 = None view_3: "f32[3, s16]" = torch.ops.aten.view.default(tangents_2, [3, primals_5]); tangents_2 = None
return (None, view_2, view_3, primals_5, primals_5) return (None, view_2, view_3, primals_5, primals_5)
""", # noqa: B950 """, # noqa: B950
) )
@ -2501,10 +2501,10 @@ class GraphModule(torch.nn.Module):
normalize_gm(fw[0].print_readable(print_output=False)), normalize_gm(fw[0].print_readable(print_output=False)),
"""\ """\
class GraphModule(torch.nn.Module): class GraphModule(torch.nn.Module):
def forward(self, primals_1: "Sym(s2)", primals_2: "Sym(s3)", primals_3: "Sym(s1)", primals_4: "f64[s0, s1]", primals_5: "i64[s2 + 1]", primals_6: "f32[s6, 0]", primals_7: "f32[s7, 0]", primals_8: "Sym(s2)", primals_9: "Sym(s1)", primals_10: "Sym(s1)"): def forward(self, primals_1: "Sym(s51)", primals_2: "Sym(s71)", primals_3: "Sym(s55)", primals_4: "f64[s64, s55]", primals_5: "i64[s51 + 1]", primals_6: "f32[s0, 0]", primals_7: "f32[s83, 0]", primals_8: "Sym(s51)", primals_9: "Sym(s55)", primals_10: "Sym(s55)"):
clone: "f64[s0, s1]" = torch.ops.aten.clone.default(primals_4); primals_4 = None clone: "f64[s64, s55]" = torch.ops.aten.clone.default(primals_4); primals_4 = None
mul: "f64[s0, s1]" = torch.ops.aten.mul.Tensor(clone, primals_1); clone = None mul: "f64[s64, s55]" = torch.ops.aten.mul.Tensor(clone, primals_1); clone = None
return (mul, primals_5, primals_6, primals_7, primals_8, primals_10, primals_10, primals_1, primals_8, primals_10) return (mul, primals_5, primals_6, primals_7, primals_8, primals_10, primals_10, primals_1, primals_8, primals_10)
""", # noqa: B950 """, # noqa: B950
) )
@ -2513,8 +2513,8 @@ class GraphModule(torch.nn.Module):
normalize_gm(bw[0].print_readable(print_output=False)), normalize_gm(bw[0].print_readable(print_output=False)),
"""\ """\
class GraphModule(torch.nn.Module): class GraphModule(torch.nn.Module):
def forward(self, primals_1: "Sym(s2)", primals_8: "Sym(s2)", primals_10: "Sym(s1)", tangents_1: "f64[s0, s1]", tangents_2: "i64[s2 + 1]", tangents_3: "f32[s6, 0]", tangents_4: "f32[s7, 0]"): def forward(self, primals_1: "Sym(s51)", primals_8: "Sym(s51)", primals_10: "Sym(s55)", tangents_1: "f64[s64, s55]", tangents_2: "i64[s51 + 1]", tangents_3: "f32[s0, 0]", tangents_4: "f32[s83, 0]"):
mul_1: "f64[s0, s1]" = torch.ops.aten.mul.Tensor(tangents_1, primals_1); tangents_1 = primals_1 = None mul_1: "f64[s64, s55]" = torch.ops.aten.mul.Tensor(tangents_1, primals_1); tangents_1 = primals_1 = None
return (None, None, None, mul_1, tangents_2, tangents_3, tangents_4, primals_8, primals_10, primals_10) return (None, None, None, mul_1, tangents_2, tangents_3, tangents_4, primals_8, primals_10, primals_10)
""", # noqa: B950 """, # noqa: B950
) )
@ -2534,11 +2534,11 @@ class GraphModule(torch.nn.Module):
normalize_gm(fw[0].print_readable(print_output=False)), normalize_gm(fw[0].print_readable(print_output=False)),
"""\ """\
class GraphModule(torch.nn.Module): class GraphModule(torch.nn.Module):
def forward(self, primals_1: "Sym(s2)", primals_2: "Sym(s3)", primals_3: "Sym(s1)", primals_4: "f64[s0, s1]", primals_5: "i64[s2 + 1]", primals_6: "f32[s6, 0]", primals_7: "f32[s7, 0]", primals_8: "Sym(s2)", primals_9: "Sym(s1)", primals_10: "Sym(s1)"): def forward(self, primals_1: "Sym(s51)", primals_2: "Sym(s71)", primals_3: "Sym(s55)", primals_4: "f64[s64, s55]", primals_5: "i64[s51 + 1]", primals_6: "f32[s0, 0]", primals_7: "f32[s83, 0]", primals_8: "Sym(s51)", primals_9: "Sym(s55)", primals_10: "Sym(s55)"):
clone: "f64[s0, s1]" = torch.ops.aten.clone.default(primals_4); primals_4 = None clone: "f64[s64, s55]" = torch.ops.aten.clone.default(primals_4); primals_4 = None
cat: "f64[s0, 2*s1]" = torch.ops.aten.cat.default([clone, clone], 1); clone = None cat: "f64[s64, 2*s55]" = torch.ops.aten.cat.default([clone, clone], 1); clone = None
add_2: "Sym(2*s1)" = primals_10 + primals_10 add_2: "Sym(2*s55)" = primals_10 + primals_10
return (cat, primals_5, primals_6, primals_7, primals_8, add_2, add_2, primals_8, primals_10, add_2) return (cat, primals_5, primals_6, primals_7, primals_8, add_2, add_2, primals_8, primals_10, add_2)
""", # noqa: B950 """, # noqa: B950
) )
@ -2547,11 +2547,11 @@ class GraphModule(torch.nn.Module):
normalize_gm(bw[0].print_readable(print_output=False)), normalize_gm(bw[0].print_readable(print_output=False)),
"""\ """\
class GraphModule(torch.nn.Module): class GraphModule(torch.nn.Module):
def forward(self, primals_8: "Sym(s2)", primals_10: "Sym(s1)", add_2: "Sym(2*s1)", tangents_1: "f64[s0, 2*s1]", tangents_2: "i64[s2 + 1]", tangents_3: "f32[s6, 0]", tangents_4: "f32[s7, 0]"): def forward(self, primals_8: "Sym(s51)", primals_10: "Sym(s55)", add_2: "Sym(2*s55)", tangents_1: "f64[s64, 2*s55]", tangents_2: "i64[s51 + 1]", tangents_3: "f32[s0, 0]", tangents_4: "f32[s83, 0]"):
slice_1: "f64[s0, s1]" = torch.ops.aten.slice.Tensor(tangents_1, 1, 0, primals_10) slice_1: "f64[s64, s55]" = torch.ops.aten.slice.Tensor(tangents_1, 1, 0, primals_10)
slice_2: "f64[s0, s1]" = torch.ops.aten.slice.Tensor(tangents_1, 1, primals_10, add_2); tangents_1 = add_2 = None slice_2: "f64[s64, s55]" = torch.ops.aten.slice.Tensor(tangents_1, 1, primals_10, add_2); tangents_1 = add_2 = None
add_4: "f64[s0, s1]" = torch.ops.aten.add.Tensor(slice_1, slice_2); slice_1 = slice_2 = None add_4: "f64[s64, s55]" = torch.ops.aten.add.Tensor(slice_1, slice_2); slice_1 = slice_2 = None
return (None, None, None, add_4, tangents_2, tangents_3, tangents_4, primals_8, primals_10, primals_10) return (None, None, None, add_4, tangents_2, tangents_3, tangents_4, primals_8, primals_10, primals_10)
""", # noqa: B950 """, # noqa: B950
) )
@ -2580,7 +2580,7 @@ class GraphModule(torch.nn.Module):
normalize_gm(fw[0].print_readable(print_output=False)), normalize_gm(fw[0].print_readable(print_output=False)),
"""\ """\
class <lambda>(torch.nn.Module): class <lambda>(torch.nn.Module):
def forward(self, arg0_1: "Sym(s3)", arg1_1: "Sym(s4)", arg2_1: "Sym(s2)", arg3_1: "f64[9, s2]", arg4_1: "i64[s3 + 1]", arg5_1: "f32[s7, 0]", arg6_1: "f32[s8, 0]", arg7_1: "Sym(s3)", arg8_1: "Sym(s2)", arg9_1: "Sym(s2)"): def forward(self, arg0_1: "Sym(s51)", arg1_1: "Sym(s71)", arg2_1: "Sym(s55)", arg3_1: "f64[9, s55]", arg4_1: "i64[s51 + 1]", arg5_1: "f32[s0, 0]", arg6_1: "f32[s83, 0]", arg7_1: "Sym(s51)", arg8_1: "Sym(s55)", arg9_1: "Sym(s55)"):
randn: "f64[2, 5]" = torch.ops.aten.randn.default([2, 5], dtype = torch.float64, device = device(type='cpu'), pin_memory = False) randn: "f64[2, 5]" = torch.ops.aten.randn.default([2, 5], dtype = torch.float64, device = device(type='cpu'), pin_memory = False)
randn_1: "f64[3, 5]" = torch.ops.aten.randn.default([3, 5], dtype = torch.float64, device = device(type='cpu'), pin_memory = False) randn_1: "f64[3, 5]" = torch.ops.aten.randn.default([3, 5], dtype = torch.float64, device = device(type='cpu'), pin_memory = False)
randn_2: "f64[4, 5]" = torch.ops.aten.randn.default([4, 5], dtype = torch.float64, device = device(type='cpu'), pin_memory = False) randn_2: "f64[4, 5]" = torch.ops.aten.randn.default([4, 5], dtype = torch.float64, device = device(type='cpu'), pin_memory = False)
@ -2594,13 +2594,13 @@ class <lambda>(torch.nn.Module):
zeros_1: "f32[2, 0]" = torch.ops.aten.zeros.default([2, 0], device = device(type='cpu'), pin_memory = False) zeros_1: "f32[2, 0]" = torch.ops.aten.zeros.default([2, 0], device = device(type='cpu'), pin_memory = False)
zeros_2: "f32[4, 0]" = torch.ops.aten.zeros.default([4, 0], device = device(type='cpu'), pin_memory = False) zeros_2: "f32[4, 0]" = torch.ops.aten.zeros.default([4, 0], device = device(type='cpu'), pin_memory = False)
cat_2: "f64[9, s2 + 5]" = torch.ops.aten.cat.default([cat, arg3_1], 1); cat = arg3_1 = None cat_2: "f64[9, s55 + 5]" = torch.ops.aten.cat.default([cat, arg3_1], 1); cat = arg3_1 = None
sin: "f64[9, s2 + 5]" = torch.ops.aten.sin.default(cat_2) sin: "f64[9, s55 + 5]" = torch.ops.aten.sin.default(cat_2)
mul: "f64[9, s2 + 5]" = torch.ops.aten.mul.Tensor(sin, 3); sin = None mul: "f64[9, s55 + 5]" = torch.ops.aten.mul.Tensor(sin, 3); sin = None
sym_size_int: "Sym(s2 + 5)" = torch.ops.aten.sym_size.int(cat_2, 1); cat_2 = None sym_size_int: "Sym(s55 + 5)" = torch.ops.aten.sym_size.int(cat_2, 1); cat_2 = None
sym_stride_int: "Sym(s2 + 5)" = torch.ops.aten.sym_stride.int(mul, 0) sym_stride_int: "Sym(s55 + 5)" = torch.ops.aten.sym_stride.int(mul, 0)
return (mul, cat_1, zeros_1, zeros_2, sym_size_int, sym_stride_int) return (mul, cat_1, zeros_1, zeros_2, sym_size_int, sym_stride_int)
""", # noqa: B950 """, # noqa: B950
) )
@ -2757,10 +2757,10 @@ class TestNestedTensor(torch._dynamo.test_case.TestCase, NestedTensorTestCase):
norm_graph, norm_graph,
"""\ """\
class GraphModule(torch.nn.Module): class GraphModule(torch.nn.Module):
def forward(self, s1: "Sym(s1)", L_nt_: "f64[3, s1, 5]"): def forward(self, s71: "Sym(s71)", L_nt_: "f64[3, s71, 5]"):
l_nt_ = L_nt_ l_nt_ = L_nt_
add: "f64[3, s1, 5]" = l_nt_ + 2; l_nt_ = None add: "f64[3, s71, 5]" = l_nt_ + 2; l_nt_ = None
return (add,) return (add,)
""", # noqa: B950 """, # noqa: B950
) )
@ -3254,27 +3254,27 @@ class GraphModule(torch.nn.Module):
# varies based on the type of view # varies based on the type of view
guard_str = "\n".join(guards) guard_str = "\n".join(guards)
if nt_view_name == "subclass_dense": if nt_view_name == "subclass_dense":
self.assertExpectedInline(guard_str, """Eq(s3 - 1, s0)""") self.assertExpectedInline(guard_str, """Eq(s85 - 1, s77)""")
elif nt_view_name == "dense_subclass_dense_subclass": elif nt_view_name == "dense_subclass_dense_subclass":
self.assertExpectedInline( self.assertExpectedInline(
guard_str, guard_str,
"""\ """\
Eq(s5 - 1, s2) Eq(s85 - 1, s77)
Eq(s12 - 1, s7) Eq(s80 - 1, s78)
Eq(s11, s9)""", Eq(s72, s71)""",
) )
elif nt_view_name.startswith("base_is_nt_True"): elif nt_view_name.startswith("base_is_nt_True"):
self.assertExpectedInline( self.assertExpectedInline(
guard_str, guard_str,
"""Eq(s3 - 1, s0)""", """Eq(s17 - 1, s83)""",
) )
else: else:
self.assertExpectedInline( self.assertExpectedInline(
guard_str, guard_str,
"""\ """\
Eq(s4 - 1, s1) Eq(s85 - 1, s64)
Eq(s13 - 1, s8) Eq(s80 - 1, s77)
Eq(s12, s10)""", Eq(s72, s71)""",
) )
return gm return gm
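The guard and graph expectations above change from sequentially allocated symbols (s0, s1, ...) to hash-derived ones (s77, s85, s35, ...) because a symbol's index is now derived from its input source name, with linear probing to keep distinct sources from colliding. A minimal sketch of that idea in plain Python; the helper names, table size, and source-name strings are illustrative, not the actual PyTorch implementation:

import hashlib

def stable_hash(name: str, table_size: int) -> int:
    # Use a stable digest rather than Python's randomized hash() so the
    # same source name maps to the same index in every process.
    digest = hashlib.sha256(name.encode()).digest()
    return int.from_bytes(digest[:8], "little") % table_size

def assign_symbol_ids(source_names, table_size=100):
    # Map each input source (e.g. "L['x'].size()[0]") to a symbol index,
    # resolving collisions with linear probing so two different sources
    # never share a symbol.
    owner = {}
    ids = {}
    for name in source_names:
        idx = stable_hash(name, table_size)
        while idx in owner and owner[idx] != name:
            idx = (idx + 1) % table_size
        owner[idx] = name
        ids[name] = f"s{idx}"
    return ids

# The ids depend only on the source names, not on how many other inputs
# happened to be marked dynamic (and later specialized) earlier in the run.
print(assign_symbol_ids(["L['x'].size()[0]", "L['y'].size()[0]"]))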


@ -1560,7 +1560,7 @@ graph():
) )
with self.assertRaisesRegex( with self.assertRaisesRegex(
error_type, error_type,
r"Real tensor propagation found an output size mismatch between fake shape s1 and real shape 4, " r"Real tensor propagation found an output size mismatch between fake shape s\d+ and real shape 4, "
r"at output\.size\(0\), for func: mylib.foo.default", r"at output\.size\(0\), for func: mylib.foo.default",
): ):
export( export(
@ -2848,7 +2848,7 @@ def forward(self, p_linear_weight, p_linear_bias, x):
with self.assertRaisesRegex( with self.assertRaisesRegex(
RuntimeError, RuntimeError,
"Expected input.*shape.*= 9 to be " "Expected input.*shape.*= 9 to be "
"of the form 2\\*s1, where s1 is an integer", "of the form 2\\*s92, where s92 is an integer",
): ):
ep.module()(torch.randn(9)) ep.module()(torch.randn(9))
@ -3506,8 +3506,11 @@ def forward(self, x):
dynamic_shapes=({0: Dim("x")},), dynamic_shapes=({0: Dim("x")},),
) )
# Since symbol names are based on a hash of the source names, and these differ # Since symbol names are based on a hash of the source names, and these differ
# across inference and training, we compare the ranges instead. # across inference and training, we compare the ranges instead.
self.assertEqual( self.assertEqual(
str(ep_for_training.range_constraints), str(ep_for_real.range_constraints) str(ep_for_training.range_constraints.values()),
str(ep_for_real.range_constraints.values()),
) )
def test_export_for_training_with_container_type(self): def test_export_for_training_with_container_type(self):
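A quick illustration of why the check above compares range_constraints.values() rather than the full mapping: when the two export paths hash different source names, the dictionary keys disagree even though the ranges themselves match. The ValueRanges class below is only a stand-in for the real one:

from dataclasses import dataclass

@dataclass
class ValueRanges:
    # Minimal stand-in for the real ValueRanges, just for this illustration.
    lower: int
    upper: int

# The two export paths derive their symbols from different source names,
# so the keys differ while the ranges are identical.
training = {"s3": ValueRanges(0, 100)}
inference = {"s77": ValueRanges(0, 100)}

assert str(training) != str(inference)                    # keys leak into the repr
assert str(training.values()) == str(inference.values())  # the ranges alone agree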
@ -4398,7 +4401,7 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
em.module()(torch.randn(4, 3)) em.module()(torch.randn(4, 3))
with self.assertRaisesRegex( with self.assertRaisesRegex(
RuntimeError, RuntimeError,
r"Runtime assertion failed for expression Eq\(Mod\(s0\*s1, s0 \- 1\), 0\)", r"Runtime assertion failed for expression Eq\(Mod\(s27\*s77, s77 \- 1\), 0\)",
): ):
em.module()(torch.randn(4, 5)) em.module()(torch.randn(4, 5))
@ -4409,7 +4412,7 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
x = torch.randn(3, 5) x = torch.randn(3, 5)
with self.assertRaisesRegex( with self.assertRaisesRegex(
RuntimeError, RuntimeError,
"Expected.*shape\\[1\\] = 5 to be of the form 2\\*s1, where s1 is an integer", "Expected.*shape\\[1\\] = 5 to be of the form 2\\*s33, where s33 is an integer",
): ):
em.module()(x) em.module()(x)
@ -4968,11 +4971,14 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
) )
self.assertEqual( self.assertEqual(
[ [
str(node.meta["val"].shape) # First dimension varies across strict and non-strict
# since the source names are different, resulting in
# different symbol names.
str(node.meta["val"].shape[1:])
for node in efoo.graph_module.graph.nodes for node in efoo.graph_module.graph.nodes
if node.op == "placeholder" if node.op == "placeholder"
], ],
["torch.Size([s0, 2, 3])", "torch.Size([s0, 3, 4])"], ["torch.Size([2, 3])", "torch.Size([3, 4])"],
) )
@testing.expectedFailureCppSerDes @testing.expectedFailureCppSerDes
@ -5110,14 +5116,10 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
"y": (batch, size, size), "y": (batch, size, size),
}, },
) )
self.assertEqual(
[ for node in efoo.graph_module.graph.nodes:
str(node.meta["val"].shape) if node.op == "placeholder":
for node in efoo.graph_module.graph.nodes self.assertEqual(node.meta["val"].shape[1], node.meta["val"].shape[2])
if node.op == "placeholder"
],
["torch.Size([s0, s1, s1])", "torch.Size([s0, s1, s1])"],
)
self.assertEqual(efoo.module()(*inputs).shape, foo(*inputs).shape) self.assertEqual(efoo.module()(*inputs).shape, foo(*inputs).shape)
# pass dynamic shapes of inputs [multiple, mostly distinct] # pass dynamic shapes of inputs [multiple, mostly distinct]
@ -5128,13 +5130,14 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
inputs, inputs,
dynamic_shapes={"x": (batch, M, K), "y": (batch, K, N)}, dynamic_shapes={"x": (batch, M, K), "y": (batch, K, N)},
) )
placeholders = [
node.meta["val"].shape
for node in efoo.graph_module.graph.nodes
if node.op == "placeholder"
]
self.assertEqual( self.assertEqual(
[ placeholders[0][2],
str(node.meta["val"].shape) placeholders[1][1],
for node in efoo.graph_module.graph.nodes
if node.op == "placeholder"
],
["torch.Size([s0, s1, s2])", "torch.Size([s0, s2, s5])"],
) )
self.assertEqual(efoo.module()(*inputs).shape, foo(*inputs).shape) self.assertEqual(efoo.module()(*inputs).shape, foo(*inputs).shape)
@ -5151,11 +5154,14 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
) )
self.assertEqual( self.assertEqual(
[ [
str(node.meta["val"].shape) # First dimension varies across strict and non-strict
# since the source names are different, resulting in
# different symbol names.
str(node.meta["val"].shape[1:])
for node in efoo.graph_module.graph.nodes for node in efoo.graph_module.graph.nodes
if node.op == "placeholder" if node.op == "placeholder"
], ],
["torch.Size([s0, 2, 3])", "torch.Size([s0, 3, 4])"], ["torch.Size([2, 3])", "torch.Size([3, 4])"],
) )
self.assertEqual(efoo.module()(*inputs).shape, foo(*inputs).shape) self.assertEqual(efoo.module()(*inputs).shape, foo(*inputs).shape)
@ -5172,11 +5178,14 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
) )
self.assertEqual( self.assertEqual(
[ [
str(node.meta["val"].shape) # First dimension varies across strict and non-strict
# since the source names are different, resulting in
# different symbol names.
str(node.meta["val"].shape[1:])
for node in efoo.graph_module.graph.nodes for node in efoo.graph_module.graph.nodes
if node.op == "placeholder" if node.op == "placeholder"
], ],
["torch.Size([s0, 2, 3])", "torch.Size([s0, 3, 4])"], ["torch.Size([2, 3])", "torch.Size([3, 4])"],
) )
self.assertEqual(efoo.module()(*inputs).shape, foo(*inputs).shape) self.assertEqual(efoo.module()(*inputs).shape, foo(*inputs).shape)
@ -5486,7 +5495,7 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
if node.op == "placeholder" if node.op == "placeholder"
] ]
self.assertEqual(len(input_shapes), 9) self.assertEqual(len(input_shapes), 9)
self.assertTrue(all(shape == "torch.Size([s0])" for shape in input_shapes)) self.assertTrue(all(shape == "torch.Size([s3])" for shape in input_shapes))
def test_error_does_not_reference_eager_fallback(self): def test_error_does_not_reference_eager_fallback(self):
class Module(torch.nn.Module): class Module(torch.nn.Module):
@ -11165,7 +11174,7 @@ def forward(self, x, y):
self.assertEqual(out2.shape, torch.ones(11, 4, 3).shape) self.assertEqual(out2.shape, torch.ones(11, 4, 3).shape)
with self.assertRaisesRegex( with self.assertRaisesRegex(
RuntimeError, RuntimeError,
r"Runtime assertion failed for expression Eq\(Mod\(s0\*s1, 4\*s0 \- 4\), 0\) on node 'eq.*'", r"Runtime assertion failed for expression Eq\(Mod\(s27\*s77, 4\*s77 \- 4\), 0\) on node 'eq.*'",
): ):
ep.module()(torch.randn(8, 8)) # fail ep.module()(torch.randn(8, 8)) # fail
@ -11197,7 +11206,7 @@ def forward(self, x, y):
self.assertEqual(out2.shape, torch.ones(40).shape) self.assertEqual(out2.shape, torch.ones(40).shape)
with self.assertRaisesRegex( with self.assertRaisesRegex(
RuntimeError, RuntimeError,
r"Runtime assertion failed for expression Eq\(s0\*s1, s2\*s3\) on node 'eq.*'", r"Runtime assertion failed for expression Eq\((.*)\) on node '.*'",
): # fail only at runtime ): # fail only at runtime
ep.module()(torch.randn(5, 8), torch.randn(4, 5), torch.randn(30)) # fail ep.module()(torch.randn(5, 8), torch.randn(4, 5), torch.randn(30)) # fail
@ -11224,7 +11233,7 @@ def forward(self, x, y):
self.assertEqual(out1.shape, torch.ones(126).shape) self.assertEqual(out1.shape, torch.ones(126).shape)
with self.assertRaisesRegex( with self.assertRaisesRegex(
RuntimeError, RuntimeError,
r"Runtime assertion failed for expression Eq\(s0\*s1\*s2, s3\) on node 'eq.*'", r"Runtime assertion failed for expression Eq\((.*)\) on node '.*'",
): # fail only at runtime ): # fail only at runtime
ep.module()(torch.randn(4, 3, 2), torch.randn(10)) # fail ep.module()(torch.randn(4, 3, 2), torch.randn(10)) # fail
@ -11305,12 +11314,12 @@ def forward(self, x, y):
) )
with self.assertRaisesRegex( with self.assertRaisesRegex(
RuntimeError, RuntimeError,
r"Runtime assertion failed for expression Ne\(s0, 20\)", r"Runtime assertion failed for expression Ne\(s77, 20\)",
): ):
ep.module()(torch.randn(20, 20, 16)) ep.module()(torch.randn(20, 20, 16))
with self.assertRaisesRegex( with self.assertRaisesRegex(
RuntimeError, RuntimeError,
r"Runtime assertion failed for expression Ne\(Mod\(s0, 20\), 0\)", r"Runtime assertion failed for expression Ne\(Mod\(s77, 20\), 0\)",
): ):
ep.module()(torch.randn(400, 20, 16)) ep.module()(torch.randn(400, 20, 16))
ep.module()(torch.randn(42, 20, 16)) ep.module()(torch.randn(42, 20, 16))
@ -11348,17 +11357,17 @@ def forward(self, x, y):
self.assertEqual(out1.shape, torch.ones(27).shape) self.assertEqual(out1.shape, torch.ones(27).shape)
with self.assertRaisesRegex( with self.assertRaisesRegex(
RuntimeError, RuntimeError,
r"Runtime assertion failed for expression Ne\(s0, s1\)", r"Runtime assertion failed for expression Ne\(s77, s17\)",
): # fail only at runtime ): # fail only at runtime
ep.module()(torch.randn(4), torch.randn(4)) # fail ep.module()(torch.randn(4), torch.randn(4)) # fail
with self.assertRaisesRegex( with self.assertRaisesRegex(
RuntimeError, RuntimeError,
r"Runtime assertion failed for expression Ne\(s0, s1\**3\)", r"Runtime assertion failed for expression Ne\(s77, s17\**3\)",
): ):
ep.module()(torch.randn(64), torch.randn(4)) # fail ep.module()(torch.randn(64), torch.randn(4)) # fail
with self.assertRaisesRegex( with self.assertRaisesRegex(
RuntimeError, RuntimeError,
r"Runtime assertion failed for expression Eq\(s0\**2, 3\*s1\)", r"Runtime assertion failed for expression Eq\(s77\**2, 3\*s17\)",
): ):
ep.module()(torch.randn(10), torch.randn(9)) # fail ep.module()(torch.randn(10), torch.randn(9)) # fail
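The updated assertions above either pin the new hash-derived names (s77, s17, ...) or loosen the pattern to any symbol (s\d+) or any expression (Eq\((.*)\)). A small comparison of the two styles using plain re; the message string is an illustrative example, not captured output:

import re

msg = "Runtime assertion failed for expression Eq(Mod(s27*s77, s77 - 1), 0)"

# Pinning the old sequential names breaks as soon as the ids change.
assert re.search(r"Eq\(Mod\(s0\*s1, s0 - 1\), 0\)", msg) is None

# Matching any s<digits> symbol keeps the assertion meaningful but
# independent of the particular hashed ids.
assert re.search(r"Eq\(Mod\(s\d+\*s\d+, s\d+ - 1\), 0\)", msg) is not None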


@ -539,8 +539,12 @@ def forward(self, x):
ep.range_constraints[symint] = ValueRanges(lower=lower_range, upper=upper_range) ep.range_constraints[symint] = ValueRanges(lower=lower_range, upper=upper_range)
serialized = ExportedProgramSerializer().serialize(ep) serialized = ExportedProgramSerializer().serialize(ep)
self.assertEqual(serialized.exported_program.range_constraints["s0"].min_val, 2) self.assertEqual(
self.assertEqual(serialized.exported_program.range_constraints["s0"].max_val, 3) serialized.exported_program.range_constraints["s77"].min_val, 2
)
self.assertEqual(
serialized.exported_program.range_constraints["s77"].max_val, 3
)
def test_kwargs_default(self) -> None: def test_kwargs_default(self) -> None:
""" """


@ -5941,8 +5941,8 @@ class TestAOTModuleSimplified(AOTTestCase):
self.assertExpectedInline( self.assertExpectedInline(
shape_env.format_guards(), shape_env.format_guards(),
"""\ """\
- Eq(s1, 20) - Eq(s49, 20)
- Eq(s2, 30)""", - Eq(s70, 30)""",
) )
assert torch.allclose(ref[0], res[0]) assert torch.allclose(ref[0], res[0])


@ -4553,10 +4553,10 @@ def forward(self, l_iter_, l_x_, l__self___dec_cond_fn, l__self___linear_bias_bo
gm.code.strip("\n"), gm.code.strip("\n"),
"""\ """\
def forward(self, arg0_1, arg1_1, arg2_1, arg3_1): def forward(self, arg0_1, arg1_1, arg2_1, arg3_1):
sym_size_int = torch.ops.aten.sym_size.int(arg2_1, 0) sym_size_int = torch.ops.aten.sym_size.int(arg3_1, 1)
sym_size_int_1 = torch.ops.aten.sym_size.int(arg2_1, 1) sym_size_int_1 = torch.ops.aten.sym_size.int(arg2_1, 1)
sym_size_int_2 = torch.ops.aten.sym_size.int(arg3_1, 0) sym_size_int_2 = torch.ops.aten.sym_size.int(arg2_1, 0)
sym_size_int_3 = torch.ops.aten.sym_size.int(arg3_1, 1) sym_size_int_3 = torch.ops.aten.sym_size.int(arg3_1, 0)
while_loop_cond_graph_0 = self.while_loop_cond_graph_0 while_loop_cond_graph_0 = self.while_loop_cond_graph_0
while_loop_body_graph_0 = self.while_loop_body_graph_0 while_loop_body_graph_0 = self.while_loop_body_graph_0
while_loop = torch.ops.higher_order.while_loop(while_loop_cond_graph_0, while_loop_body_graph_0, (arg0_1, arg1_1, arg2_1, arg3_1), (sym_size_int, sym_size_int_1, sym_size_int_2, sym_size_int_3)); while_loop_cond_graph_0 = while_loop_body_graph_0 = arg0_1 = arg1_1 = arg2_1 = arg3_1 = sym_size_int = sym_size_int_1 = sym_size_int_2 = sym_size_int_3 = None while_loop = torch.ops.higher_order.while_loop(while_loop_cond_graph_0, while_loop_body_graph_0, (arg0_1, arg1_1, arg2_1, arg3_1), (sym_size_int, sym_size_int_1, sym_size_int_2, sym_size_int_3)); while_loop_cond_graph_0 = while_loop_body_graph_0 = arg0_1 = arg1_1 = arg2_1 = arg3_1 = sym_size_int = sym_size_int_1 = sym_size_int_2 = sym_size_int_3 = None
@ -4719,10 +4719,10 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1
def forward(self, a_1, b_1): def forward(self, a_1, b_1):
sum_1 = torch.ops.aten.sum.default(a_1) sum_1 = torch.ops.aten.sum.default(a_1)
gt = torch.ops.aten.gt.Scalar(sum_1, 0); sum_1 = None gt = torch.ops.aten.gt.Scalar(sum_1, 0); sum_1 = None
sym_size_int = torch.ops.aten.sym_size.int(a_1, 0) sym_size_int = torch.ops.aten.sym_size.int(a_1, 1)
sym_size_int_1 = torch.ops.aten.sym_size.int(a_1, 1) sym_size_int_1 = torch.ops.aten.sym_size.int(b_1, 0)
sym_size_int_2 = torch.ops.aten.sym_size.int(b_1, 0) sym_size_int_2 = torch.ops.aten.sym_size.int(b_1, 1)
sym_size_int_3 = torch.ops.aten.sym_size.int(b_1, 1) sym_size_int_3 = torch.ops.aten.sym_size.int(a_1, 0)
true_graph_0 = self.true_graph_0 true_graph_0 = self.true_graph_0
false_graph_0 = self.false_graph_0 false_graph_0 = self.false_graph_0
cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, (a_1, b_1, sym_size_int, sym_size_int_1, sym_size_int_2, sym_size_int_3)); gt = true_graph_0 = false_graph_0 = a_1 = b_1 = sym_size_int = sym_size_int_1 = sym_size_int_2 = sym_size_int_3 = None cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, (a_1, b_1, sym_size_int, sym_size_int_1, sym_size_int_2, sym_size_int_3)); gt = true_graph_0 = false_graph_0 = a_1 = b_1 = sym_size_int = sym_size_int_1 = sym_size_int_2 = sym_size_int_3 = None
@ -5856,7 +5856,7 @@ def forward(self, x_1):
sym_size_int_1 = torch.ops.aten.sym_size.int(x_1, 1) sym_size_int_1 = torch.ops.aten.sym_size.int(x_1, 1)
true_graph_0 = self.true_graph_0 true_graph_0 = self.true_graph_0
false_graph_0 = self.false_graph_0 false_graph_0 = self.false_graph_0
cond = torch.ops.higher_order.cond(eq, true_graph_0, false_graph_0, (x_1, sym_size_int, sym_size_int_1)); eq = true_graph_0 = false_graph_0 = x_1 = sym_size_int = sym_size_int_1 = None cond = torch.ops.higher_order.cond(eq, true_graph_0, false_graph_0, (x_1, sym_size_int_1, sym_size_int)); eq = true_graph_0 = false_graph_0 = x_1 = sym_size_int_1 = sym_size_int = None
getitem = cond[0]; cond = None getitem = cond[0]; cond = None
return getitem""", # noqa: B950 return getitem""", # noqa: B950
) )
@ -5969,7 +5969,7 @@ def forward(self, x_1):
false_graph_0 = self.false_graph_0 false_graph_0 = self.false_graph_0
_tensor_constant0 = self._tensor_constant0 _tensor_constant0 = self._tensor_constant0
_tensor_constant1 = self._tensor_constant1 _tensor_constant1 = self._tensor_constant1
cond = torch.ops.higher_order.cond(eq, true_graph_0, false_graph_0, (x_1, _tensor_constant0, sym_size_int, sym_size_int_1, _tensor_constant1)); eq = true_graph_0 = false_graph_0 = x_1 = _tensor_constant0 = sym_size_int = sym_size_int_1 = _tensor_constant1 = None cond = torch.ops.higher_order.cond(eq, true_graph_0, false_graph_0, (x_1, _tensor_constant0, sym_size_int_1, sym_size_int, _tensor_constant1)); eq = true_graph_0 = false_graph_0 = x_1 = _tensor_constant0 = sym_size_int_1 = sym_size_int = _tensor_constant1 = None
getitem = cond[0]; cond = None getitem = cond[0]; cond = None
return getitem""", # noqa: B950 return getitem""", # noqa: B950
) )
@ -6209,7 +6209,7 @@ def forward(self, x_1):
sym_size_int_1 = torch.ops.aten.sym_size.int(x_1, 1) sym_size_int_1 = torch.ops.aten.sym_size.int(x_1, 1)
true_graph_0 = self.true_graph_0 true_graph_0 = self.true_graph_0
false_graph_0 = self.false_graph_0 false_graph_0 = self.false_graph_0
cond = torch.ops.higher_order.cond(eq, true_graph_0, false_graph_0, (x_1, sym_size_int, sym_size_int_1)); eq = true_graph_0 = false_graph_0 = x_1 = sym_size_int = sym_size_int_1 = None cond = torch.ops.higher_order.cond(eq, true_graph_0, false_graph_0, (x_1, sym_size_int_1, sym_size_int)); eq = true_graph_0 = false_graph_0 = x_1 = sym_size_int_1 = sym_size_int = None
getitem = cond[0]; cond = None getitem = cond[0]; cond = None
return getitem""", # noqa: B950 return getitem""", # noqa: B950
) )
@ -6558,14 +6558,14 @@ def forward(self, l_inp_, l_tmp_):
self.assertExpectedInline( self.assertExpectedInline(
backend.graphs[0].code.strip(), backend.graphs[0].code.strip(),
"""\ """\
def forward(self, s0 : torch.SymInt, L_a_ : torch.Tensor, L_b_ : torch.Tensor, L_self_num : torch.SymInt): def forward(self, s97 : torch.SymInt, L_a_ : torch.Tensor, L_b_ : torch.Tensor, L_self_num : torch.SymInt):
l_a_ = L_a_ l_a_ = L_a_
l_b_ = L_b_ l_b_ = L_b_
l_self_num = L_self_num l_self_num = L_self_num
tensor = torch.tensor([True]) tensor = torch.tensor([True])
cond_true_0 = self.cond_true_0 cond_true_0 = self.cond_true_0
cond_false_0 = self.cond_false_0 cond_false_0 = self.cond_false_0
cond = torch.ops.higher_order.cond(tensor, cond_true_0, cond_false_0, (l_a_, l_b_, l_self_num, s0)); tensor = cond_true_0 = cond_false_0 = l_a_ = l_b_ = l_self_num = s0 = None cond = torch.ops.higher_order.cond(tensor, cond_true_0, cond_false_0, (l_a_, l_b_, l_self_num, s97)); tensor = cond_true_0 = cond_false_0 = l_a_ = l_b_ = l_self_num = s97 = None
getitem = cond[0]; cond = None getitem = cond[0]; cond = None
return (getitem,)""", # noqa: B950 return (getitem,)""", # noqa: B950
) )
@ -6753,10 +6753,10 @@ def forward(self, L_init_ : torch.Tensor, L_xs_ : torch.Tensor, L_add_closure_0_
"""\ """\
class GraphModule(torch.nn.Module): class GraphModule(torch.nn.Module):
def forward(self, x): def forward(self, x):
x: "f32[s0, 3]"; x: "f32[s35, 3]";
x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec) x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
sym_size_int_1: "Sym(s0)" = torch.ops.aten.sym_size.int(x, 0) sym_size_int_1: "Sym(s35)" = torch.ops.aten.sym_size.int(x, 0)
while_loop_cond_graph_0 = self.while_loop_cond_graph_0 while_loop_cond_graph_0 = self.while_loop_cond_graph_0
while_loop_body_graph_0 = self.while_loop_body_graph_0 while_loop_body_graph_0 = self.while_loop_body_graph_0
@ -6770,27 +6770,27 @@ class GraphModule(torch.nn.Module):
gt_1: "Sym(u1 > 0)" = getitem_2 > 0 gt_1: "Sym(u1 > 0)" = getitem_2 > 0
_assert_scalar_default_1 = torch.ops.aten._assert_scalar.default(gt_1, "Runtime assertion failed for expression 0 < u1 on node 'gt_1'"); gt_1 = _assert_scalar_default_1 = None _assert_scalar_default_1 = torch.ops.aten._assert_scalar.default(gt_1, "Runtime assertion failed for expression 0 < u1 on node 'gt_1'"); gt_1 = _assert_scalar_default_1 = None
getitem_1: "f32[s0, 3]" = while_loop[1]; while_loop = None getitem_1: "f32[s35, 3]" = while_loop[1]; while_loop = None
add: "Sym(u1 + 1)" = getitem_2 + 1 add: "Sym(u1 + 1)" = getitem_2 + 1
add_1: "f32[s0, 3]" = torch.ops.aten.add.Tensor(getitem_1, getitem_2); getitem_1 = None add_1: "f32[s35, 3]" = torch.ops.aten.add.Tensor(getitem_1, getitem_2); getitem_1 = None
lt: "Sym(u1 < s0)" = getitem_2 < sym_size_int_1; sym_size_int_1 = None lt: "Sym(u1 < s35)" = getitem_2 < sym_size_int_1; sym_size_int_1 = None
mul: "Sym(2*u1)" = getitem_2 * 2; getitem_2 = None mul: "Sym(2*u1)" = getitem_2 * 2; getitem_2 = None
ones: "f32[2*u1]" = torch.ops.aten.ones.default([mul], device = device(type='cpu'), pin_memory = False); mul = None ones: "f32[2*u1]" = torch.ops.aten.ones.default([mul], device = device(type='cpu'), pin_memory = False); mul = None
return pytree.tree_unflatten((add, add_1, lt, ones), self._out_spec) return pytree.tree_unflatten((add, add_1, lt, ones), self._out_spec)
class while_loop_cond_graph_0(torch.nn.Module): class while_loop_cond_graph_0(torch.nn.Module):
def forward(self, it_1: "Sym(u0)", x_1: "f32[s0, 3]"): def forward(self, it_1: "Sym(u0)", x_1: "f32[s35, 3]"):
sym_size_int: "Sym(s0)" = torch.ops.aten.sym_size.int(x_1, 0); x_1 = None sym_size_int: "Sym(s35)" = torch.ops.aten.sym_size.int(x_1, 0); x_1 = None
lt: "Sym(u0 < s0)" = it_1 < sym_size_int; it_1 = sym_size_int = None lt: "Sym(u0 < s35)" = it_1 < sym_size_int; it_1 = sym_size_int = None
return lt return lt
class while_loop_body_graph_0(torch.nn.Module): class while_loop_body_graph_0(torch.nn.Module):
def forward(self, it_1: "Sym(u0)", x_1: "f32[s0, 3]"): def forward(self, it_1: "Sym(u0)", x_1: "f32[s35, 3]"):
clone: "f32[s0, 3]" = torch.ops.aten.clone.default(x_1); x_1 = None clone: "f32[s35, 3]" = torch.ops.aten.clone.default(x_1); x_1 = None
select: "f32[3]" = torch.ops.aten.select.int(clone, 0, it_1) select: "f32[3]" = torch.ops.aten.select.int(clone, 0, it_1)
select_1: "f32[3]" = torch.ops.aten.select.int(clone, 0, it_1) select_1: "f32[3]" = torch.ops.aten.select.int(clone, 0, it_1)
add: "f32[3]" = torch.ops.aten.add.Tensor(select_1, it_1); select_1 = None add: "f32[3]" = torch.ops.aten.add.Tensor(select_1, it_1); select_1 = None
@ -6820,12 +6820,12 @@ class GraphModule(torch.nn.Module):
normalize_gm(backend.graphs[0].print_readable(print_output=False)), normalize_gm(backend.graphs[0].print_readable(print_output=False)),
"""\ """\
class GraphModule(torch.nn.Module): class GraphModule(torch.nn.Module):
def forward(self, s0: "Sym(s0)", s1: "Sym(s1)", L_x_: "f32[s0, s1]"): def forward(self, s77: "Sym(s77)", s27: "Sym(s27)", L_x_: "f32[s77, s27]"):
l_x_ = L_x_ l_x_ = L_x_
cond_fn_0 = self.cond_fn_0 cond_fn_0 = self.cond_fn_0
body_fn_0 = self.body_fn_0 body_fn_0 = self.body_fn_0
while_loop = torch.ops.higher_order.while_loop(cond_fn_0, body_fn_0, (0, l_x_), (s0, s1)); cond_fn_0 = body_fn_0 = l_x_ = s1 = None while_loop = torch.ops.higher_order.while_loop(cond_fn_0, body_fn_0, (0, l_x_), (s27, s77)); cond_fn_0 = body_fn_0 = l_x_ = s27 = None
getitem_4: "Sym(u1)" = while_loop[0] getitem_4: "Sym(u1)" = while_loop[0]
@ -6835,49 +6835,49 @@ class GraphModule(torch.nn.Module):
gt_1: "Sym(u1 > 0)" = getitem_4 > 0 gt_1: "Sym(u1 > 0)" = getitem_4 > 0
_assert_scalar_default_1 = torch.ops.aten._assert_scalar.default(gt_1, "Runtime assertion failed for expression 0 < u1 on node 'gt_1'"); gt_1 = _assert_scalar_default_1 = None _assert_scalar_default_1 = torch.ops.aten._assert_scalar.default(gt_1, "Runtime assertion failed for expression 0 < u1 on node 'gt_1'"); gt_1 = _assert_scalar_default_1 = None
out_x: "f32[s0, s1]" = while_loop[1]; while_loop = None out_x: "f32[s77, s27]" = while_loop[1]; while_loop = None
add: "Sym(u1 + 1)" = getitem_4 + 1 add: "Sym(u1 + 1)" = getitem_4 + 1
add_1: "f32[s0, s1]" = getitem_4 + out_x; out_x = None add_1: "f32[s77, s27]" = getitem_4 + out_x; out_x = None
lt: "Sym(u1 < s0)" = getitem_4 < s0; s0 = None lt: "Sym(u1 < s77)" = getitem_4 < s77; s77 = None
mul: "Sym(2*u1)" = getitem_4 * 2; getitem_4 = None mul: "Sym(2*u1)" = getitem_4 * 2; getitem_4 = None
ones: "f32[2*u1]" = torch.ones(mul); mul = None ones: "f32[2*u1]" = torch.ones(mul); mul = None
return (add, add_1, lt, ones) return (add, add_1, lt, ones)
class cond_fn_0(torch.nn.Module): class cond_fn_0(torch.nn.Module):
def forward(self, unbacked_symint: "Sym(u0)", l_x_: "f32[s0, s1]", s0, s1): def forward(self, unbacked_symint: "Sym(u0)", l_x_: "f32[s77, s27]", s27, s77):
s0_1 = s0 s27_1 = s27
s1_1 = s1 s77_1 = s77
size = l_x_.size(); l_x_ = None size = l_x_.size(); l_x_ = None
getitem: "Sym(s0)" = size[0] getitem: "Sym(s77)" = size[0]
getitem_1: "Sym(s1)" = size[1]; size = getitem_1 = None getitem_1: "Sym(s27)" = size[1]; size = getitem_1 = None
lt: "Sym(u0 < s0)" = unbacked_symint < getitem; unbacked_symint = getitem = None lt: "Sym(u0 < s77)" = unbacked_symint < getitem; unbacked_symint = getitem = None
return lt return lt
class body_fn_0(torch.nn.Module): class body_fn_0(torch.nn.Module):
def forward(self, unbacked_symint: "Sym(u0)", l_x_: "f32[s0, s1]", s0, s1): def forward(self, unbacked_symint: "Sym(u0)", l_x_: "f32[s77, s27]", s27, s77):
s0_1 = s0 s27_1 = s27
s1_1 = s1 s77_1 = s77
x_clone: "f32[s0, s1]" = l_x_.clone() x_clone: "f32[s77, s27]" = l_x_.clone()
ge: "Sym(u0 >= 0)" = unbacked_symint >= 0 ge: "Sym(u0 >= 0)" = unbacked_symint >= 0
_check = torch._check(ge); ge = _check = None _check = torch._check(ge); ge = _check = None
size = l_x_.size(); l_x_ = None size = l_x_.size(); l_x_ = None
getitem: "Sym(s0)" = size[0] getitem: "Sym(s77)" = size[0]
getitem_1: "Sym(s1)" = size[1]; size = getitem_1 = None getitem_1: "Sym(s27)" = size[1]; size = getitem_1 = None
lt: "Sym(u0 < s0)" = unbacked_symint < getitem; getitem = None lt: "Sym(u0 < s77)" = unbacked_symint < getitem; getitem = None
_check_1 = torch._check(lt); lt = _check_1 = None _check_1 = torch._check(lt); lt = _check_1 = None
select: "f32[s1]" = x_clone.select(0, unbacked_symint) select: "f32[s27]" = x_clone.select(0, unbacked_symint)
select_1: "f32[s1]" = x_clone.select(0, unbacked_symint) select_1: "f32[s27]" = x_clone.select(0, unbacked_symint)
add: "f32[s1]" = select_1 + unbacked_symint; select_1 = None add: "f32[s27]" = select_1 + unbacked_symint; select_1 = None
copy_: "f32[s1]" = select.copy_(add); select = add = copy_ = None copy_: "f32[s27]" = select.copy_(add); select = add = copy_ = None
add_1: "Sym(u0 + 1)" = unbacked_symint + 1; unbacked_symint = None add_1: "Sym(u0 + 1)" = unbacked_symint + 1; unbacked_symint = None
return (add_1, x_clone) return (add_1, x_clone)
@ -7048,12 +7048,12 @@ class GraphModule(torch.nn.Module):
"""\ """\
class GraphModule(torch.nn.Module): class GraphModule(torch.nn.Module):
def forward(self, x): def forward(self, x):
x: "f32[s0, 3]"; x: "f32[s77, 3]";
x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec) x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
sym_size_int_1: "Sym(s0)" = torch.ops.aten.sym_size.int(x, 0) sym_size_int_1: "Sym(s77)" = torch.ops.aten.sym_size.int(x, 0)
sin: "f32[s0, 3]" = torch.ops.aten.sin.default(x); x = None sin: "f32[s77, 3]" = torch.ops.aten.sin.default(x); x = None
while_loop_cond_graph_0 = self.while_loop_cond_graph_0 while_loop_cond_graph_0 = self.while_loop_cond_graph_0
while_loop_body_graph_0 = self.while_loop_body_graph_0 while_loop_body_graph_0 = self.while_loop_body_graph_0
@ -7065,19 +7065,19 @@ class GraphModule(torch.nn.Module):
getitem_9: "Sym(u8)" = while_loop[3] getitem_9: "Sym(u8)" = while_loop[3]
getitem_10: "Sym(u9)" = while_loop[4] getitem_10: "Sym(u9)" = while_loop[4]
getitem_5: "f32[s0, 3]" = while_loop[5]; while_loop = None getitem_5: "f32[s77, 3]" = while_loop[5]; while_loop = None
add: "Sym(u7 + 1)" = getitem_8 + 1 add: "Sym(u7 + 1)" = getitem_8 + 1
add_1: "Sym(u8 + 1)" = getitem_9 + 1 add_1: "Sym(u8 + 1)" = getitem_9 + 1
add_2: "Sym(u9 + 1)" = getitem_10 + 1 add_2: "Sym(u9 + 1)" = getitem_10 + 1
add_3: "f32[s0, 3]" = torch.ops.aten.add.Tensor(getitem_5, getitem_8); getitem_8 = None add_3: "f32[s77, 3]" = torch.ops.aten.add.Tensor(getitem_5, getitem_8); getitem_8 = None
add_4: "f32[s0, 3]" = torch.ops.aten.add.Tensor(getitem_5, getitem_9); getitem_9 = None add_4: "f32[s77, 3]" = torch.ops.aten.add.Tensor(getitem_5, getitem_9); getitem_9 = None
add_5: "f32[s0, 3]" = torch.ops.aten.add.Tensor(getitem_5, getitem_10); getitem_10 = None add_5: "f32[s77, 3]" = torch.ops.aten.add.Tensor(getitem_5, getitem_10); getitem_10 = None
return pytree.tree_unflatten((getitem_6, getitem_7, add, add_1, add_2, add_3, add_4, add_5, getitem_5), self._out_spec) return pytree.tree_unflatten((getitem_6, getitem_7, add, add_1, add_2, add_3, add_4, add_5, getitem_5), self._out_spec)
class while_loop_cond_graph_0(torch.nn.Module): class while_loop_cond_graph_0(torch.nn.Module):
def forward(self, arg0_1: "Sym(u15)", arg1_1: "Sym(u16)", arg2_1: "Sym(u17)", arg3_1: "Sym(u18)", arg4_1: "Sym(u19)", arg5_1: "f32[s0, 3]"): def forward(self, arg0_1: "Sym(u15)", arg1_1: "Sym(u16)", arg2_1: "Sym(u17)", arg3_1: "Sym(u18)", arg4_1: "Sym(u19)", arg5_1: "f32[s77, 3]"):
mul: "Sym(u17*u18)" = arg2_1 * arg3_1; arg2_1 = arg3_1 = None mul: "Sym(u17*u18)" = arg2_1 * arg3_1; arg2_1 = arg3_1 = None
mul_1: "Sym(u17*u18*u19)" = mul * arg4_1; mul = arg4_1 = None mul_1: "Sym(u17*u18*u19)" = mul * arg4_1; mul = arg4_1 = None
mul_2: "Sym(u15*u16)" = arg0_1 * arg1_1; arg0_1 = arg1_1 = None mul_2: "Sym(u15*u16)" = arg0_1 * arg1_1; arg0_1 = arg1_1 = None
@ -7085,7 +7085,7 @@ class GraphModule(torch.nn.Module):
return lt return lt
class while_loop_body_graph_0(torch.nn.Module): class while_loop_body_graph_0(torch.nn.Module):
def forward(self, arg0_1: "Sym(u15)", arg1_1: "Sym(u16)", arg2_1: "Sym(u17)", arg3_1: "Sym(u18)", arg4_1: "Sym(u19)", arg5_1: "f32[s0, 3]"): def forward(self, arg0_1: "Sym(u15)", arg1_1: "Sym(u16)", arg2_1: "Sym(u17)", arg3_1: "Sym(u18)", arg4_1: "Sym(u19)", arg5_1: "f32[s77, 3]"):
add: "Sym(u15 + 1)" = arg0_1 + 1; arg0_1 = None add: "Sym(u15 + 1)" = arg0_1 + 1; arg0_1 = None
add_1: "Sym(u16 + 1)" = arg1_1 + 1; arg1_1 = None add_1: "Sym(u16 + 1)" = arg1_1 + 1; arg1_1 = None
@ -7093,7 +7093,7 @@ class GraphModule(torch.nn.Module):
add_3: "Sym(u18 + 1)" = arg3_1 + 1; arg3_1 = None add_3: "Sym(u18 + 1)" = arg3_1 + 1; arg3_1 = None
add_4: "Sym(u19 + 1)" = arg4_1 + 1; arg4_1 = None add_4: "Sym(u19 + 1)" = arg4_1 + 1; arg4_1 = None
add_5: "f32[s0, 3]" = torch.ops.aten.add.Tensor(arg5_1, 1); arg5_1 = None add_5: "f32[s77, 3]" = torch.ops.aten.add.Tensor(arg5_1, 1); arg5_1 = None
return (add, add_1, add_2, add_3, add_4, add_5) return (add, add_1, add_2, add_3, add_4, add_5)
""", # noqa: B950 """, # noqa: B950
) )
@ -7119,14 +7119,14 @@ class GraphModule(torch.nn.Module):
normalize_gm(backend.graphs[0].print_readable(print_output=False)), normalize_gm(backend.graphs[0].print_readable(print_output=False)),
"""\ """\
class GraphModule(torch.nn.Module): class GraphModule(torch.nn.Module):
def forward(self, s0: "Sym(s0)", s1: "Sym(s1)", L_x_: "f32[s0, s1]"): def forward(self, s77: "Sym(s77)", s27: "Sym(s27)", L_x_: "f32[s77, s27]"):
l_x_ = L_x_ l_x_ = L_x_
child: "f32[s0, s1]" = l_x_.sin(); l_x_ = None child: "f32[s77, s27]" = l_x_.sin(); l_x_ = None
cond_fn_0 = self.cond_fn_0 cond_fn_0 = self.cond_fn_0
body_fn_0 = self.body_fn_0 body_fn_0 = self.body_fn_0
while_loop = torch.ops.higher_order.while_loop(cond_fn_0, body_fn_0, (s0, s1, 2, 2, 3, child), (s0, s1)); cond_fn_0 = body_fn_0 = s0 = s1 = child = None while_loop = torch.ops.higher_order.while_loop(cond_fn_0, body_fn_0, (s77, s27, 2, 2, 3, child), (s27, s77)); cond_fn_0 = body_fn_0 = s77 = s27 = child = None
getitem_10: "Sym(u5)" = while_loop[0] getitem_10: "Sym(u5)" = while_loop[0]
getitem_11: "Sym(u6)" = while_loop[1] getitem_11: "Sym(u6)" = while_loop[1]
@ -7134,21 +7134,21 @@ class GraphModule(torch.nn.Module):
getitem_13: "Sym(u8)" = while_loop[3] getitem_13: "Sym(u8)" = while_loop[3]
getitem_14: "Sym(u9)" = while_loop[4] getitem_14: "Sym(u9)" = while_loop[4]
out_x: "f32[s0, s1]" = while_loop[5]; while_loop = None out_x: "f32[s77, s27]" = while_loop[5]; while_loop = None
add: "Sym(u7 + 1)" = getitem_12 + 1 add: "Sym(u7 + 1)" = getitem_12 + 1
add_1: "Sym(u8 + 1)" = getitem_13 + 1 add_1: "Sym(u8 + 1)" = getitem_13 + 1
add_2: "Sym(u9 + 1)" = getitem_14 + 1 add_2: "Sym(u9 + 1)" = getitem_14 + 1
add_3: "f32[s0, s1]" = getitem_12 + out_x; getitem_12 = None add_3: "f32[s77, s27]" = getitem_12 + out_x; getitem_12 = None
add_4: "f32[s0, s1]" = getitem_13 + out_x; getitem_13 = None add_4: "f32[s77, s27]" = getitem_13 + out_x; getitem_13 = None
add_5: "f32[s0, s1]" = getitem_14 + out_x; getitem_14 = None add_5: "f32[s77, s27]" = getitem_14 + out_x; getitem_14 = None
return (getitem_10, getitem_11, add, add_1, add_2, add_3, add_4, add_5, out_x) return (getitem_10, getitem_11, add, add_1, add_2, add_3, add_4, add_5, out_x)
class cond_fn_0(torch.nn.Module): class cond_fn_0(torch.nn.Module):
def forward(self, unbacked_symint: "Sym(u0)", unbacked_symint_0: "Sym(u1)", unbacked_symint_1: "Sym(u2)", unbacked_symint_2: "Sym(u3)", unbacked_symint_3: "Sym(u4)", child: "f32[s0, s1]", s0, s1): def forward(self, unbacked_symint: "Sym(u0)", unbacked_symint_0: "Sym(u1)", unbacked_symint_1: "Sym(u2)", unbacked_symint_2: "Sym(u3)", unbacked_symint_3: "Sym(u4)", child: "f32[s77, s27]", s27, s77):
s0_1 = s0 s27_1 = s27
s1_1 = s1 s77_1 = s77
mul: "Sym(u2*u3)" = unbacked_symint_1 * unbacked_symint_2; unbacked_symint_1 = unbacked_symint_2 = None mul: "Sym(u2*u3)" = unbacked_symint_1 * unbacked_symint_2; unbacked_symint_1 = unbacked_symint_2 = None
mul_1: "Sym(u2*u3*u4)" = mul * unbacked_symint_3; mul = unbacked_symint_3 = None mul_1: "Sym(u2*u3*u4)" = mul * unbacked_symint_3; mul = unbacked_symint_3 = None
@ -7157,9 +7157,9 @@ class GraphModule(torch.nn.Module):
return lt return lt
class body_fn_0(torch.nn.Module): class body_fn_0(torch.nn.Module):
def forward(self, unbacked_symint: "Sym(u0)", unbacked_symint_0: "Sym(u1)", unbacked_symint_1: "Sym(u2)", unbacked_symint_2: "Sym(u3)", unbacked_symint_3: "Sym(u4)", child: "f32[s0, s1]", s0, s1): def forward(self, unbacked_symint: "Sym(u0)", unbacked_symint_0: "Sym(u1)", unbacked_symint_1: "Sym(u2)", unbacked_symint_2: "Sym(u3)", unbacked_symint_3: "Sym(u4)", child: "f32[s77, s27]", s27, s77):
s0_1 = s0 s27_1 = s27
s1_1 = s1 s77_1 = s77
add: "Sym(u0 + 1)" = unbacked_symint + 1; unbacked_symint = None add: "Sym(u0 + 1)" = unbacked_symint + 1; unbacked_symint = None
add_1: "Sym(u1 + 1)" = unbacked_symint_0 + 1; unbacked_symint_0 = None add_1: "Sym(u1 + 1)" = unbacked_symint_0 + 1; unbacked_symint_0 = None
@ -7168,7 +7168,7 @@ class GraphModule(torch.nn.Module):
add_3: "Sym(u3 + 1)" = unbacked_symint_2 + 1; unbacked_symint_2 = None add_3: "Sym(u3 + 1)" = unbacked_symint_2 + 1; unbacked_symint_2 = None
add_4: "Sym(u4 + 1)" = unbacked_symint_3 + 1; unbacked_symint_3 = None add_4: "Sym(u4 + 1)" = unbacked_symint_3 + 1; unbacked_symint_3 = None
child_1: "f32[s0, s1]" = child + 1; child = None child_1: "f32[s77, s27]" = child + 1; child = None
return (add, add_1, add_2, add_3, add_4, child_1) return (add, add_1, add_2, add_3, add_4, child_1)
""", # noqa: B950 """, # noqa: B950
) )
@ -7336,30 +7336,30 @@ class GraphModule(torch.nn.Module):
"""\ """\
class GraphModule(torch.nn.Module): class GraphModule(torch.nn.Module):
def forward(self, x, y, z): def forward(self, x, y, z):
x: "f32[s0, 3]"; y: "f32[s1]"; z: "f32[s0, 3]"; x: "f32[s35, 3]"; y: "f32[s58]"; z: "f32[s35, 3]";
x, y, z, = fx_pytree.tree_flatten_spec(([x, y, z], {}), self._in_spec) x, y, z, = fx_pytree.tree_flatten_spec(([x, y, z], {}), self._in_spec)
sym_size_int_3: "Sym(s0)" = torch.ops.aten.sym_size.int(x, 0) sym_size_int_3: "Sym(s35)" = torch.ops.aten.sym_size.int(x, 0)
sym_size_int_4: "Sym(s1)" = torch.ops.aten.sym_size.int(y, 0); y = None sym_size_int_4: "Sym(s58)" = torch.ops.aten.sym_size.int(y, 0); y = None
gt: "Sym(s0 > 5)" = sym_size_int_3 > 5 gt: "Sym(s35 > 5)" = sym_size_int_3 > 5
true_graph_0 = self.true_graph_0 true_graph_0 = self.true_graph_0
false_graph_0 = self.false_graph_0 false_graph_0 = self.false_graph_0
cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, (x, sym_size_int_4, sym_size_int_3, z)); gt = true_graph_0 = false_graph_0 = x = sym_size_int_4 = sym_size_int_3 = z = None cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, (x, sym_size_int_4, sym_size_int_3, z)); gt = true_graph_0 = false_graph_0 = x = sym_size_int_4 = sym_size_int_3 = z = None
getitem: "f32[s0, 3]" = cond[0]; cond = None getitem: "f32[s35, 3]" = cond[0]; cond = None
return pytree.tree_unflatten((getitem,), self._out_spec) return pytree.tree_unflatten((getitem,), self._out_spec)
class true_graph_0(torch.nn.Module): class true_graph_0(torch.nn.Module):
def forward(self, x: "f32[s0, 3]", sym_size_int_4: "Sym(s1)", sym_size_int_3: "Sym(s0)", z: "f32[s0, 3]"): def forward(self, x: "f32[s35, 3]", sym_size_int_4: "Sym(s58)", sym_size_int_3: "Sym(s35)", z: "f32[s35, 3]"):
add: "f32[s0, 3]" = torch.ops.aten.add.Tensor(x, sym_size_int_4); x = sym_size_int_4 = None add: "f32[s35, 3]" = torch.ops.aten.add.Tensor(x, sym_size_int_4); x = sym_size_int_4 = None
return (add,) return (add,)
class false_graph_0(torch.nn.Module): class false_graph_0(torch.nn.Module):
def forward(self, x: "f32[s0, 3]", sym_size_int_4: "Sym(s1)", sym_size_int_3: "Sym(s0)", z: "f32[s0, 3]"): def forward(self, x: "f32[s35, 3]", sym_size_int_4: "Sym(s58)", sym_size_int_3: "Sym(s35)", z: "f32[s35, 3]"):
mul: "f32[s0, 3]" = torch.ops.aten.mul.Tensor(z, sym_size_int_3); z = sym_size_int_3 = None mul: "f32[s35, 3]" = torch.ops.aten.mul.Tensor(z, sym_size_int_3); z = sym_size_int_3 = None
add: "f32[s0, 3]" = torch.ops.aten.add.Tensor(x, mul); x = mul = None add: "f32[s35, 3]" = torch.ops.aten.add.Tensor(x, mul); x = mul = None
return (add,) return (add,)
""", # noqa: B950 """, # noqa: B950
) )
@ -7522,7 +7522,7 @@ class GraphModule(torch.nn.Module):
normalize_gm(bk.graphs[0].print_readable(print_output=False)), normalize_gm(bk.graphs[0].print_readable(print_output=False)),
"""\ """\
class GraphModule(torch.nn.Module): class GraphModule(torch.nn.Module):
def forward(self, s0: "Sym(s0)", s1: "Sym(s1)", L_y_: "f32[s0, s1]", L_z_: "f32[s0, s1]", L_x_: "f32[s0, s1]"): def forward(self, s17: "Sym(s17)", s94: "Sym(s94)", L_y_: "f32[s17, s94]", L_z_: "f32[s17, s94]", L_x_: "f32[s17, s94]"):
l_y_ = L_y_ l_y_ = L_y_
l_z_ = L_z_ l_z_ = L_z_
l_x_ = L_x_ l_x_ = L_x_
@ -7532,39 +7532,39 @@ class GraphModule(torch.nn.Module):
cond_true_0 = self.cond_true_0 cond_true_0 = self.cond_true_0
cond_false_0 = self.cond_false_0 cond_false_0 = self.cond_false_0
cond = torch.ops.higher_order.cond(gt, cond_true_0, cond_false_0, (l_x_, s1, s0, s0, l_z_)); gt = cond_true_0 = cond_false_0 = l_x_ = s1 = s0 = l_z_ = None cond = torch.ops.higher_order.cond(gt, cond_true_0, cond_false_0, (l_x_, s94, s17, s17, l_z_)); gt = cond_true_0 = cond_false_0 = l_x_ = s94 = s17 = l_z_ = None
getitem_5: "f32[u0, s1]" = cond[0] getitem_5: "f32[u0, s94]" = cond[0]
sym_size_int: "Sym(u0)" = torch.ops.aten.sym_size.int(getitem_5, 0); getitem_5 = None sym_size_int: "Sym(u0)" = torch.ops.aten.sym_size.int(getitem_5, 0); getitem_5 = None
_check_is_size = torch._check_is_size(sym_size_int); _check_is_size = None _check_is_size = torch._check_is_size(sym_size_int); _check_is_size = None
ge: "Sym(u0 >= 0)" = sym_size_int >= 0; sym_size_int = None ge: "Sym(u0 >= 0)" = sym_size_int >= 0; sym_size_int = None
_assert_scalar_default = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge = _assert_scalar_default = None _assert_scalar_default = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge = _assert_scalar_default = None
ret: "f32[u0, s1]" = cond[0]; cond = None ret: "f32[u0, s94]" = cond[0]; cond = None
sum_2: "f32[]" = l_y_.sum(); l_y_ = None sum_2: "f32[]" = l_y_.sum(); l_y_ = None
sub: "f32[u0, s1]" = sum_2 - ret; sum_2 = ret = None sub: "f32[u0, s94]" = sum_2 - ret; sum_2 = ret = None
return (sub,) return (sub,)
class cond_true_0(torch.nn.Module): class cond_true_0(torch.nn.Module):
def forward(self, l_x_, s1, s0_true_branch, getitem_2_false_branch, l_z__false_branch): def forward(self, l_x_, s94, s17_true_branch, getitem_2_false_branch, l_z__false_branch):
l_x__1 = l_x_ l_x__1 = l_x_
s1_1 = s1 s94_1 = s94
add: "f32[s0, s1]" = l_x__1 + s0_true_branch; l_x__1 = s0_true_branch = None add: "f32[s17, s94]" = l_x__1 + s17_true_branch; l_x__1 = s17_true_branch = None
getitem: "f32[s0 - 2, s1]" = add[slice(2, None, None)]; add = None getitem: "f32[s17 - 2, s94]" = add[slice(2, None, None)]; add = None
clone: "f32[s0 - 2, s1]" = getitem.clone(); getitem = None clone: "f32[s17 - 2, s94]" = getitem.clone(); getitem = None
return (clone,) return (clone,)
class cond_false_0(torch.nn.Module): class cond_false_0(torch.nn.Module):
def forward(self, l_x_, s1, s0_true_branch, getitem_2_false_branch, l_z__false_branch): def forward(self, l_x_, s94, s17_true_branch, getitem_2_false_branch, l_z__false_branch):
l_x__1 = l_x_ l_x__1 = l_x_
s1_1 = s1 s94_1 = s94
mul: "f32[s0, s1]" = getitem_2_false_branch * l_z__false_branch; getitem_2_false_branch = l_z__false_branch = None mul: "f32[s17, s94]" = getitem_2_false_branch * l_z__false_branch; getitem_2_false_branch = l_z__false_branch = None
add: "f32[s0, s1]" = l_x__1 + mul; l_x__1 = mul = None add: "f32[s17, s94]" = l_x__1 + mul; l_x__1 = mul = None
getitem: "f32[2, s1]" = add[slice(None, 2, None)]; add = None getitem: "f32[2, s94]" = add[slice(None, 2, None)]; add = None
clone: "f32[2, s1]" = getitem.clone(); getitem = None clone: "f32[2, s94]" = getitem.clone(); getitem = None
return (clone,) return (clone,)
""", # noqa: B950 """, # noqa: B950
) )
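The expected graph dumps in this file hard-code the hash-derived names (s77, s27, s35, s17, s94, ...), so they will need regenerating whenever symbol assignment changes again. One name-agnostic way to compare such dumps, shown purely as a sketch and not what these tests do, is to canonicalize symbols by order of first appearance before comparing:

import re

def canonicalize_symbols(graph_text: str) -> str:
    # Rewrite every s<digits> symbol to z0, z1, ... in order of first
    # appearance, so dumps that differ only in symbol numbering compare equal.
    mapping = {}

    def rename(match):
        name = match.group(0)
        if name not in mapping:
            mapping[name] = f"z{len(mapping)}"
        return mapping[name]

    return re.sub(r"\bs\d+\b", rename, graph_text)

old = 'def forward(self, s0: "Sym(s0)", L_x_: "f32[s0, 3]"):'
new = 'def forward(self, s77: "Sym(s77)", L_x_: "f32[s77, 3]"):'
assert canonicalize_symbols(old) == canonicalize_symbols(new)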


@ -403,10 +403,10 @@ arg2_1 = arg3_1 = arg1_1 = arg0_1 = foo_default = None
self.assertExpectedInline( self.assertExpectedInline(
post_grad_graphs, post_grad_graphs,
"""\ """\
def forward(self, arg0_1: "Sym(s0)", arg1_1: "f32[s0][1]cpu", arg2_1: "f32[s0][1]cpu", arg3_1: "f32[s0][1]cpu", arg4_1: "f32[s0][1]cpu", arg5_1: "f32[s0][1]cpu"): def forward(self, arg0_1: "Sym(s72)", arg1_1: "f32[s72][1]cpu", arg2_1: "f32[s72][1]cpu", arg3_1: "f32[s72][1]cpu", arg4_1: "f32[s72][1]cpu", arg5_1: "f32[s72][1]cpu"):
foo_default = torch.ops.mylib.foo.default(arg3_1, [arg4_1, arg5_1], arg2_1, 2, arg1_1); arg4_1 = arg5_1 = arg1_1 = foo_default = None foo_default = torch.ops.mylib.foo.default(arg3_1, [arg4_1, arg5_1], arg2_1, 2, arg1_1); arg4_1 = arg5_1 = arg1_1 = foo_default = None
copy_: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg2_1, arg2_1); arg2_1 = copy_ = None copy_: "f32[s72][1]cpu" = torch.ops.aten.copy_.default(arg2_1, arg2_1); arg2_1 = copy_ = None
copy__1: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg3_1, arg3_1); arg3_1 = copy__1 = None copy__1: "f32[s72][1]cpu" = torch.ops.aten.copy_.default(arg3_1, arg3_1); arg3_1 = copy__1 = None
return ()""", # noqa: B950 return ()""", # noqa: B950
ignore_comments=True, ignore_comments=True,
ignore_empty_lines=True, ignore_empty_lines=True,
@ -563,13 +563,13 @@ def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3
self.assertExpectedInline( self.assertExpectedInline(
graph_aot, graph_aot,
"""\ """\
def forward(self, arg0_1: "Sym(s0)", arg1_1: "f32[s0][1]cpu", arg2_1: "f32[s0][1]cpu"): def forward(self, arg0_1: "Sym(s17)", arg1_1: "f32[s17][1]cpu", arg2_1: "f32[s17][1]cpu"):
auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.mylib.foo.default, _x_base_index = 0, _y_base_index = 1, _all_bases = [arg2_1, arg1_1]) auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.mylib.foo.default, _x_base_index = 0, _y_base_index = 1, _all_bases = [arg2_1, arg1_1])
getitem_1: "f32[s0][1]cpu" = auto_functionalized_v2[1] getitem_1: "f32[s17][1]cpu" = auto_functionalized_v2[1]
getitem_2: "f32[s0][1]cpu" = auto_functionalized_v2[2]; auto_functionalized_v2 = None getitem_2: "f32[s17][1]cpu" = auto_functionalized_v2[2]; auto_functionalized_v2 = None
add: "f32[s0][1]cpu" = torch.ops.aten.add.Tensor(getitem_1, getitem_2) add: "f32[s17][1]cpu" = torch.ops.aten.add.Tensor(getitem_1, getitem_2)
copy_: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg1_1, getitem_2); arg1_1 = getitem_2 = copy_ = None copy_: "f32[s17][1]cpu" = torch.ops.aten.copy_.default(arg1_1, getitem_2); arg1_1 = getitem_2 = copy_ = None
copy__1: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg2_1, getitem_1); arg2_1 = getitem_1 = copy__1 = None copy__1: "f32[s17][1]cpu" = torch.ops.aten.copy_.default(arg2_1, getitem_1); arg2_1 = getitem_1 = copy__1 = None
return (add,)""", # noqa: B950 return (add,)""", # noqa: B950
ignore_comments=True, ignore_comments=True,
ignore_empty_lines=True, ignore_empty_lines=True,
@ -595,11 +595,11 @@ def forward(self, arg0_1: "f32[2][1]cpu", arg1_1: "f32[2][1]cpu"):
self.assertExpectedInline( self.assertExpectedInline(
graph_inductor, graph_inductor,
"""\ """\
def forward(self, arg0_1: "Sym(s0)", arg1_1: "f32[s0][1]cpu", arg2_1: "f32[s0][1]cpu"): def forward(self, arg0_1: "Sym(s17)", arg1_1: "f32[s17][1]cpu", arg2_1: "f32[s17][1]cpu"):
foo_default = torch.ops.mylib.foo.default(arg2_1, arg1_1); foo_default = None foo_default = torch.ops.mylib.foo.default(arg2_1, arg1_1); foo_default = None
add: "f32[s0][1]cpu" = torch.ops.aten.add.Tensor(arg2_1, arg1_1) add: "f32[s17][1]cpu" = torch.ops.aten.add.Tensor(arg2_1, arg1_1)
copy_: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1); arg1_1 = copy_ = None copy_: "f32[s17][1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1); arg1_1 = copy_ = None
copy__1: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg2_1, arg2_1); arg2_1 = copy__1 = None copy__1: "f32[s17][1]cpu" = torch.ops.aten.copy_.default(arg2_1, arg2_1); arg2_1 = copy__1 = None
return (add,)""", return (add,)""",
ignore_comments=True, ignore_comments=True,
ignore_empty_lines=True, ignore_empty_lines=True,
@ -663,10 +663,10 @@ def forward(self, arg0_1: "f32[2][1]cpu", arg1_1: "f32[2][1]cpu"):
self.assertExpectedInline( self.assertExpectedInline(
graph_aot, graph_aot,
"""\ """\
def forward(self, arg0_1: "Sym(s0)", arg1_1: "f32[s0][1]cpu"): def forward(self, arg0_1: "Sym(s77)", arg1_1: "f32[s77][1]cpu"):
auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.mylib.foo.default, _x_base_index = 0, _x_size = (), _x_stride = (), _x_storage_offset = 0, _y_base_index = 0, _y_size = (), _y_stride = (), _y_storage_offset = 1, _all_bases = [arg1_1]) auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.mylib.foo.default, _x_base_index = 0, _x_size = (), _x_stride = (), _x_storage_offset = 0, _y_base_index = 0, _y_size = (), _y_stride = (), _y_storage_offset = 1, _all_bases = [arg1_1])
getitem_1: "f32[s0][1]cpu" = auto_functionalized_v2[1]; auto_functionalized_v2 = None getitem_1: "f32[s77][1]cpu" = auto_functionalized_v2[1]; auto_functionalized_v2 = None
copy_: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg1_1, getitem_1); arg1_1 = getitem_1 = copy_ = None copy_: "f32[s77][1]cpu" = torch.ops.aten.copy_.default(arg1_1, getitem_1); arg1_1 = getitem_1 = copy_ = None
return ()""", # noqa: B950 return ()""", # noqa: B950
ignore_comments=True, ignore_comments=True,
ignore_empty_lines=True, ignore_empty_lines=True,
@ -691,11 +691,11 @@ def forward(self, arg0_1: "f32[2][1]cpu"):
self.assertExpectedInline( self.assertExpectedInline(
graph_inductor, graph_inductor,
"""\ """\
def forward(self, arg0_1: "Sym(s0)", arg1_1: "f32[s0][1]cpu"): def forward(self, arg0_1: "Sym(s77)", arg1_1: "f32[s77][1]cpu"):
as_strided_default: "f32[][]cpu" = torch.ops.aten.as_strided.default(arg1_1, [], [], 0) as_strided_default: "f32[][]cpu" = torch.ops.aten.as_strided.default(arg1_1, [], [], 0)
as_strided_default_1: "f32[][]cpu" = torch.ops.aten.as_strided.default(arg1_1, [], [], 1) as_strided_default_1: "f32[][]cpu" = torch.ops.aten.as_strided.default(arg1_1, [], [], 1)
foo_default = torch.ops.mylib.foo.default(as_strided_default, as_strided_default_1); as_strided_default = as_strided_default_1 = foo_default = None foo_default = torch.ops.mylib.foo.default(as_strided_default, as_strided_default_1); as_strided_default = as_strided_default_1 = foo_default = None
copy_: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1); arg1_1 = copy_ = None copy_: "f32[s77][1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1); arg1_1 = copy_ = None
return ()""", # noqa: B950 return ()""", # noqa: B950
ignore_comments=True, ignore_comments=True,
ignore_empty_lines=True, ignore_empty_lines=True,
@ -1291,14 +1291,14 @@ def forward(self, arg0_1: "f32[10, 10][10, 1]cpu"):
self.assertExpectedInline( self.assertExpectedInline(
graph_aot, graph_aot,
"""\ """\
def forward(self, arg0_1: "Sym(s0)", arg1_1: "f32[s0, s0][s0, 1]cpu"): def forward(self, arg0_1: "Sym(s77)", arg1_1: "f32[s77, s77][s77, 1]cpu"):
floordiv: "Sym(0)" = 0 // arg0_1; arg0_1 = None floordiv: "Sym(0)" = 0 // arg0_1; arg0_1 = None
add_6: "Sym(2)" = floordiv + 2 add_6: "Sym(2)" = floordiv + 2
auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.mylib.foo.default, _x_base_index = 0, _x_slice_dim = 0, _x_slice_start = floordiv, _x_slice_end = add_6, _y_base_index = 0, _y_slice_dim = 1, _y_slice_start = 3, _y_slice_end = 4, _all_bases = [arg1_1]); floordiv = add_6 = None auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.mylib.foo.default, _x_base_index = 0, _x_slice_dim = 0, _x_slice_start = floordiv, _x_slice_end = add_6, _y_base_index = 0, _y_slice_dim = 1, _y_slice_start = 3, _y_slice_end = 4, _all_bases = [arg1_1]); floordiv = add_6 = None
getitem_1: "f32[s0, s0][s0, 1]cpu" = auto_functionalized_v2[1]; auto_functionalized_v2 = None getitem_1: "f32[s77, s77][s77, 1]cpu" = auto_functionalized_v2[1]; auto_functionalized_v2 = None
copy_: "f32[s0, s0][s0, 1]cpu" = torch.ops.aten.copy_.default(arg1_1, getitem_1); arg1_1 = copy_ = None copy_: "f32[s77, s77][s77, 1]cpu" = torch.ops.aten.copy_.default(arg1_1, getitem_1); arg1_1 = copy_ = None
slice_3: "f32[2, s0][s0, 1]cpu" = torch.ops.aten.slice.Tensor(getitem_1, 0, 0, 2) slice_3: "f32[2, s77][s77, 1]cpu" = torch.ops.aten.slice.Tensor(getitem_1, 0, 0, 2)
slice_4: "f32[s0, 1][s0, 1]cpu" = torch.ops.aten.slice.Tensor(getitem_1, 1, 3, 4); getitem_1 = None slice_4: "f32[s77, 1][s77, 1]cpu" = torch.ops.aten.slice.Tensor(getitem_1, 1, 3, 4); getitem_1 = None
return (slice_3, slice_4)""", # noqa: B950 return (slice_3, slice_4)""", # noqa: B950
ignore_comments=True, ignore_comments=True,
ignore_empty_lines=True, ignore_empty_lines=True,
@ -1324,13 +1324,13 @@ def forward(self, arg0_1: "f32[10, 10][10, 1]cpu"):
self.assertExpectedInline( self.assertExpectedInline(
graph_inductor, graph_inductor,
"""\ """\
def forward(self, arg0_1: "Sym(s0)", arg1_1: "f32[s0, s0][s0, 1]cpu"): def forward(self, arg0_1: "Sym(s77)", arg1_1: "f32[s77, s77][s77, 1]cpu"):
slice_tensor: "f32[2, s0][s0, 1]cpu" = torch.ops.aten.slice.Tensor(arg1_1, 0, 0, 2) slice_tensor: "f32[2, s77][s77, 1]cpu" = torch.ops.aten.slice.Tensor(arg1_1, 0, 0, 2)
slice_tensor_1: "f32[s0, 1][s0, 1]cpu" = torch.ops.aten.slice.Tensor(arg1_1, 1, 3, 4) slice_tensor_1: "f32[s77, 1][s77, 1]cpu" = torch.ops.aten.slice.Tensor(arg1_1, 1, 3, 4)
foo_default = torch.ops.mylib.foo.default(slice_tensor, slice_tensor_1); slice_tensor = slice_tensor_1 = foo_default = None foo_default = torch.ops.mylib.foo.default(slice_tensor, slice_tensor_1); slice_tensor = slice_tensor_1 = foo_default = None
copy_: "f32[s0, s0][s0, 1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1); copy_ = None copy_: "f32[s77, s77][s77, 1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1); copy_ = None
slice_3: "f32[2, s0][s0, 1]cpu" = torch.ops.aten.slice.Tensor(arg1_1, 0, 0, 2) slice_3: "f32[2, s77][s77, 1]cpu" = torch.ops.aten.slice.Tensor(arg1_1, 0, 0, 2)
slice_4: "f32[s0, 1][s0, 1]cpu" = torch.ops.aten.slice.Tensor(arg1_1, 1, 3, 4); arg1_1 = None slice_4: "f32[s77, 1][s77, 1]cpu" = torch.ops.aten.slice.Tensor(arg1_1, 1, 3, 4); arg1_1 = None
return (slice_3, slice_4)""", # noqa: B950 return (slice_3, slice_4)""", # noqa: B950
ignore_comments=True, ignore_comments=True,
ignore_empty_lines=True, ignore_empty_lines=True,
@ -1470,18 +1470,18 @@ def forward(self, arg0_1: "f32[10, 10][10, 1]cpu"):
self.assertExpectedInline( self.assertExpectedInline(
graph_aot, graph_aot,
"""\ """\
def forward(self, arg0_1: "Sym(s0)", arg1_1: "f32[s0][1]cpu"): def forward(self, arg0_1: "Sym(s77)", arg1_1: "f32[s77][1]cpu"):
clone: "f32[s0][1]cpu" = torch.ops.aten.clone.default(arg1_1) clone: "f32[s77][1]cpu" = torch.ops.aten.clone.default(arg1_1)
nonzero: "i64[u0, 1][1, u0]cpu" = torch.ops.aten.nonzero.default(clone); clone = None nonzero: "i64[u0, 1][1, u0]cpu" = torch.ops.aten.nonzero.default(clone); clone = None
sym_size_int_1: "Sym(u0)" = torch.ops.aten.sym_size.int(nonzero, 0) sym_size_int_1: "Sym(u0)" = torch.ops.aten.sym_size.int(nonzero, 0)
ge_1: "Sym(u0 >= 0)" = sym_size_int_1 >= 0; sym_size_int_1 = None ge_1: "Sym(u0 >= 0)" = sym_size_int_1 >= 0; sym_size_int_1 = None
_assert_scalar = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge_1 = _assert_scalar = None _assert_scalar = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge_1 = _assert_scalar = None
_to_copy: "f32[u0, 1][1, u0]cpu" = torch.ops.aten._to_copy.default(nonzero, dtype = torch.float32); nonzero = None _to_copy: "f32[u0, 1][1, u0]cpu" = torch.ops.aten._to_copy.default(nonzero, dtype = torch.float32); nonzero = None
auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.mylib.foo.default, _x_base_index = 0, _x_alias = True, _y_base_index = 1, _y_alias = True, _all_bases = [arg1_1, _to_copy]); _to_copy = None auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.mylib.foo.default, _x_base_index = 0, _x_alias = True, _y_base_index = 1, _y_alias = True, _all_bases = [arg1_1, _to_copy]); _to_copy = None
getitem_1: "f32[s0][1]cpu" = auto_functionalized_v2[1] getitem_1: "f32[s77][1]cpu" = auto_functionalized_v2[1]
getitem_2: "f32[u0, 1][1, u0]cpu" = auto_functionalized_v2[2]; auto_functionalized_v2 = None getitem_2: "f32[u0, 1][1, u0]cpu" = auto_functionalized_v2[2]; auto_functionalized_v2 = None
copy_: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg1_1, getitem_1); arg1_1 = copy_ = None copy_: "f32[s77][1]cpu" = torch.ops.aten.copy_.default(arg1_1, getitem_1); arg1_1 = copy_ = None
alias_1: "f32[s0][1]cpu" = torch.ops.aten.alias.default(getitem_1); getitem_1 = None alias_1: "f32[s77][1]cpu" = torch.ops.aten.alias.default(getitem_1); getitem_1 = None
slice_2: "f32[u0, 1][1, u0]cpu" = torch.ops.aten.slice.Tensor(getitem_2); getitem_2 = None slice_2: "f32[u0, 1][1, u0]cpu" = torch.ops.aten.slice.Tensor(getitem_2); getitem_2 = None
return (alias_1, slice_2)""", # noqa: B950 return (alias_1, slice_2)""", # noqa: B950
ignore_comments=True, ignore_comments=True,
@ -1517,16 +1517,16 @@ def forward(self, arg0_1: "f32[2][1]cpu"):
self.assertExpectedInline( self.assertExpectedInline(
graph_inductor, graph_inductor,
"""\ """\
def forward(self, arg0_1: "Sym(s0)", arg1_1: "f32[s0][1]cpu"): def forward(self, arg0_1: "Sym(s77)", arg1_1: "f32[s77][1]cpu"):
nonzero: "i64[u0, 1][1, u0]cpu" = torch.ops.aten.nonzero.default(arg1_1) nonzero: "i64[u0, 1][1, u0]cpu" = torch.ops.aten.nonzero.default(arg1_1)
sym_size_int_1: "Sym(u0)" = torch.ops.aten.sym_size.int(nonzero, 0) sym_size_int_1: "Sym(u0)" = torch.ops.aten.sym_size.int(nonzero, 0)
ge_1: "Sym(u0 >= 0)" = sym_size_int_1 >= 0; sym_size_int_1 = None ge_1: "Sym(u0 >= 0)" = sym_size_int_1 >= 0; sym_size_int_1 = None
_assert_scalar = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge_1 = _assert_scalar = None _assert_scalar = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge_1 = _assert_scalar = None
convert_element_type: "f32[u0, 1][1, u0]cpu" = torch.ops.prims.convert_element_type.default(nonzero, torch.float32); nonzero = None convert_element_type: "f32[u0, 1][1, u0]cpu" = torch.ops.prims.convert_element_type.default(nonzero, torch.float32); nonzero = None
alias_default: "f32[s0][1]cpu" = torch.ops.aten.alias.default(arg1_1) alias_default: "f32[s77][1]cpu" = torch.ops.aten.alias.default(arg1_1)
alias_default_1: "f32[u0, 1][1, u0]cpu" = torch.ops.aten.alias.default(convert_element_type) alias_default_1: "f32[u0, 1][1, u0]cpu" = torch.ops.aten.alias.default(convert_element_type)
foo_default = torch.ops.mylib.foo.default(alias_default, alias_default_1); alias_default = alias_default_1 = foo_default = None foo_default = torch.ops.mylib.foo.default(alias_default, alias_default_1); alias_default = alias_default_1 = foo_default = None
copy_: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1); copy_ = None copy_: "f32[s77][1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1); copy_ = None
slice_2: "f32[u0, 1][1, u0]cpu" = torch.ops.aten.slice.Tensor(convert_element_type); convert_element_type = None slice_2: "f32[u0, 1][1, u0]cpu" = torch.ops.aten.slice.Tensor(convert_element_type); convert_element_type = None
return (arg1_1, slice_2)""", # noqa: B950 return (arg1_1, slice_2)""", # noqa: B950
ignore_comments=True, ignore_comments=True,


@ -53,11 +53,13 @@ class TestMemoryPlanning(TestCase):
result, code = run_and_get_cpp_code(compiled, *args) result, code = run_and_get_cpp_code(compiled, *args)
FileCheck().check( FileCheck().check(
"pool1 = empty_strided_" + GPU_TYPE + "((4*s0*s1 + align(4*s0*s0), ), (1, )" "pool1 = empty_strided_"
+ GPU_TYPE
+ "((4*s27*s77 + align(4*s77*s77), ), (1, )"
).check_next( ).check_next(
"buf0 = alloc_from_pool(pool1, 0, torch.float32, (s0, s0), (s0, 1))" "buf0 = alloc_from_pool(pool1, 0, torch.float32, (s77, s77), (s77, 1))"
).check( ).check(
"buf1 = alloc_from_pool(pool1, align(4*s0*s0)," "buf1 = alloc_from_pool(pool1, align(4*s77*s77),"
).run( ).run(
code code
) )
@ -95,7 +97,7 @@ class TestMemoryPlanning(TestCase):
) )
FileCheck().check( FileCheck().check(
"int64_t int_array_2[] = {24L + align(12L*s0), };" "int64_t int_array_2[] = {24L + align(12L*s77), };"
).check_next("int64_t int_array_3[] = {1L, };").check_next( ).check_next("int64_t int_array_3[] = {1L, };").check_next(
"AtenTensorHandle pool1_handle;" "AtenTensorHandle pool1_handle;"
).check_next( ).check_next(
@ -103,7 +105,7 @@ class TestMemoryPlanning(TestCase):
).check_next( ).check_next(
"RAIIAtenTensorHandle pool1(pool1_handle);" "RAIIAtenTensorHandle pool1(pool1_handle);"
).check_next( ).check_next(
"int64_t int_array_4[] = {s0, 3L};" "int64_t int_array_4[] = {s77, 3L};"
).check_next( ).check_next(
"int64_t int_array_5[] = {3L, 1L};" "int64_t int_array_5[] = {3L, 1L};"
).check_next( ).check_next(


@ -11819,7 +11819,7 @@ class CommonTemplate:
# 'i1 + 3 * i0' is cached. # 'i1 + 3 * i0' is cached.
self.assertTrue( self.assertTrue(
"i0 + 2 * i1" in mul_buf.data.inner_fn_str() "i0 + 2 * i1" in mul_buf.data.inner_fn_str()
or "i0 + i1 * s1" in mul_buf.data.inner_fn_str() or "i0 + i1 * s64" in mul_buf.data.inner_fn_str()
) )
with add_scheduler_init_hook(hook_fn): with add_scheduler_init_hook(hook_fn):
@ -12548,7 +12548,7 @@ class CommonTemplate:
torch.testing.assert_close(ref, act, atol=1e-3, rtol=1e-3) torch.testing.assert_close(ref, act, atol=1e-3, rtol=1e-3)
if is_dynamic_shape_enabled(): if is_dynamic_shape_enabled():
size_assert_pattern = r"assert_size_stride.[a-z]+[0-9]+, .2, 3, s1, s2, s2., .3\*s1\*s2\*s2, s1\*s2\*s2, 1, s1\*s2, s1.." # noqa: B950 size_assert_pattern = r"assert_size_stride.[a-z]+[0-9]+, .2, 3, s12, s80, s80., .3\*s12\*s80\*s80, s12\*s80\*s80, 1, s12\*s80, s1.." # noqa: B950
else: else:
size_assert_pattern = r"assert_size_stride.[a-z]+[0-9]+, .2, 3, 16, 32, 32., .49152, 16384, 1, 512, 16.." size_assert_pattern = r"assert_size_stride.[a-z]+[0-9]+, .2, 3, 16, 32, 32., .49152, 16384, 1, 512, 16.."
FileCheck().check_regex(size_assert_pattern).run(code) FileCheck().check_regex(size_assert_pattern).run(code)
@ -12567,8 +12567,8 @@ class CommonTemplate:
code = run_and_get_triton_code(f, x) code = run_and_get_triton_code(f, x)
if is_dynamic_shape_enabled(): if is_dynamic_shape_enabled():
FileCheck().check("assert_size_stride(buf1, (s0, s1), (s1, 1))").check( FileCheck().check("assert_size_stride(buf1, (s77, s27), (s27, 1))").check(
"assert_size_stride(buf2, (s0, s1), (s1, 1))" "assert_size_stride(buf2, (s77, s27), (s27, 1))"
).run(code) ).run(code)
else: else:
FileCheck().check("assert_size_stride(buf1, (16, 32), (32, 1))").check( FileCheck().check("assert_size_stride(buf1, (16, 32), (32, 1))").check(


@ -456,7 +456,7 @@ def forward(self, x_1, output_1):
self.assertIn("output_handles[0] = ", code) self.assertIn("output_handles[0] = ", code)
self.assertIn("output_handles[1] = ", code) self.assertIn("output_handles[1] = ", code)
else: else:
self.assertIn("return (buf0, s0, )", code) self.assertIn("return (buf0, s92, )", code)
else: else:
self.assertIn( self.assertIn(
"output_handles[0] = " "output_handles[0] = "


@ -393,7 +393,7 @@ class TestPySymInt(TestCase):
self.assertEqual(res_and, 0b1000) self.assertEqual(res_and, 0b1000)
self.assertIsInstance(res_and, torch.SymInt, msg=type(res_and)) self.assertIsInstance(res_and, torch.SymInt, msg=type(res_and))
self.assertExpectedInline( self.assertExpectedInline(
str(shape_env.guards[0][0]), """Eq(BitwiseFn_bitwise_and(s0, s1), 8)""" str(shape_env.guards[0][0]), """Eq(BitwiseFn_bitwise_and(s97, s26), 8)"""
) )
a1 = create_symint(shape_env, 3) a1 = create_symint(shape_env, 3)
@ -415,7 +415,7 @@ class TestPySymInt(TestCase):
self.assertEqual(res_or, 0b1110) self.assertEqual(res_or, 0b1110)
self.assertIsInstance(res_or, torch.SymInt, msg=type(res_or)) self.assertIsInstance(res_or, torch.SymInt, msg=type(res_or))
self.assertExpectedInline( self.assertExpectedInline(
str(shape_env.guards[0][0]), """Eq(BitwiseFn_bitwise_or(s0, s1), 14)""" str(shape_env.guards[0][0]), """Eq(BitwiseFn_bitwise_or(s97, s26), 14)"""
) )
def test_stride(self): def test_stride(self):
@ -497,7 +497,7 @@ class TestPySymInt(TestCase):
shape_env = ShapeEnv() shape_env = ShapeEnv()
a0 = create_symint(shape_env, 2) a0 = create_symint(shape_env, 2)
self.assertEqual(guard_int(a0), 2) self.assertEqual(guard_int(a0), 2)
self.assertExpectedInline(str(shape_env.guards[0][0]), """Eq(s0, 2)""") self.assertExpectedInline(str(shape_env.guards[0][0]), """Eq(s97, 2)""")
def test_sym_sum(self): def test_sym_sum(self):
shape_env = ShapeEnv() shape_env = ShapeEnv()
@ -512,7 +512,7 @@ class TestPySymInt(TestCase):
shape_env = ShapeEnv(prefer_deferred_runtime_asserts_over_guards=True) shape_env = ShapeEnv(prefer_deferred_runtime_asserts_over_guards=True)
s0 = create_symint(shape_env, 2) s0 = create_symint(shape_env, 2)
self.assertEqual(guard_int(s0), 2) self.assertEqual(guard_int(s0), 2)
self.assertExpectedInline(str(shape_env.guards[0][0]), """Eq(s0, 2)""") self.assertExpectedInline(str(shape_env.guards[0][0]), """Eq(s97, 2)""")
shape_env = ShapeEnv(prefer_deferred_runtime_asserts_over_guards=True) shape_env = ShapeEnv(prefer_deferred_runtime_asserts_over_guards=True)
s0 = create_symint(shape_env, 2) s0 = create_symint(shape_env, 2)
@ -520,7 +520,7 @@ class TestPySymInt(TestCase):
self.assertEqual(len(shape_env.guards), 0) self.assertEqual(len(shape_env.guards), 0)
self.assertExpectedInline( self.assertExpectedInline(
str([ra.expr for ra in shape_env.deferred_runtime_asserts[None]]), str([ra.expr for ra in shape_env.deferred_runtime_asserts[None]]),
"""[Eq(s0, 2)]""", """[Eq(s97, 2)]""",
) )
def test_sym_int(self): def test_sym_int(self):
@ -529,14 +529,14 @@ class TestPySymInt(TestCase):
r = sym_int(a0) r = sym_int(a0)
self.assertEqual(r, 5) self.assertEqual(r, 5)
self.assertIsInstance(r, torch.SymInt, msg=type(r)) self.assertIsInstance(r, torch.SymInt, msg=type(r))
self.assertExpectedInline(str(shape_env.guards[0][0]), """Eq(s0, 5)""") self.assertExpectedInline(str(shape_env.guards[0][0]), """Eq(s97, 5)""")
a1 = create_symint(shape_env, 7) a1 = create_symint(shape_env, 7)
r = sym_int(a1 / 2) r = sym_int(a1 / 2)
self.assertEqual(guard_int(r), 3) self.assertEqual(guard_int(r), 3)
self.assertIsInstance(r, torch.SymInt, msg=type(r)) self.assertIsInstance(r, torch.SymInt, msg=type(r))
self.assertExpectedInline( self.assertExpectedInline(
str(shape_env.guards[1][0]), """Eq(TruncToInt(IntTrueDiv(s1, 2)), 3)""" str(shape_env.guards[1][0]), """Eq(TruncToInt(IntTrueDiv(s26, 2)), 3)"""
) )
a3 = create_symint(shape_env, 3) a3 = create_symint(shape_env, 3)
@ -544,7 +544,7 @@ class TestPySymInt(TestCase):
self.assertEqual(guard_int(r), 6) self.assertEqual(guard_int(r), 6)
self.assertIsInstance(r, torch.SymInt, msg=type(r)) self.assertIsInstance(r, torch.SymInt, msg=type(r))
self.assertExpectedInline( self.assertExpectedInline(
str(shape_env.guards[2][0]), """Eq(TruncToInt(2.0*ToFloat(s2)), 6)""" str(shape_env.guards[2][0]), """Eq(TruncToInt(2.0*ToFloat(s57)), 6)"""
) )
def test_sym_log2(self): def test_sym_log2(self):
@ -554,7 +554,7 @@ class TestPySymInt(TestCase):
self.assertEqual(r, 2.0) self.assertEqual(r, 2.0)
self.assertIsInstance(r, torch.SymFloat, msg=type(r)) self.assertIsInstance(r, torch.SymFloat, msg=type(r))
self.assertExpectedInline( self.assertExpectedInline(
str(shape_env.guards[0][0]), """Eq(OpaqueUnaryFn_log2(ToFloat(s0)), 2.0)""" str(shape_env.guards[0][0]), """Eq(OpaqueUnaryFn_log2(ToFloat(s97)), 2.0)"""
) )
def test_sym_sqrt(self): def test_sym_sqrt(self):
@ -564,7 +564,7 @@ class TestPySymInt(TestCase):
self.assertEqual(r, 2) self.assertEqual(r, 2)
self.assertIsInstance(r, torch.SymFloat, msg=type(r)) self.assertIsInstance(r, torch.SymFloat, msg=type(r))
self.assertExpectedInline( self.assertExpectedInline(
str(shape_env.guards[0][0]), """Eq(OpaqueUnaryFn_sqrt(ToFloat(s0)), 2.0)""" str(shape_env.guards[0][0]), """Eq(OpaqueUnaryFn_sqrt(ToFloat(s97)), 2.0)"""
) )
def test_sym_floor(self): def test_sym_floor(self):
@ -575,14 +575,14 @@ class TestPySymInt(TestCase):
self.assertIsInstance(r, torch.SymInt, msg=type(r)) self.assertIsInstance(r, torch.SymInt, msg=type(r))
self.assertExpectedInline( self.assertExpectedInline(
str(shape_env.guards[0][0]), str(shape_env.guards[0][0]),
"""Eq(FloorToInt(IntTrueDiv(s0, 2)), 2)""", """Eq(FloorToInt(IntTrueDiv(s97, 2)), 2)""",
) )
r = math.floor(3.0 * a0) r = math.floor(3.0 * a0)
self.assertEqual(r, 15) self.assertEqual(r, 15)
self.assertIsInstance(r, torch.SymInt, msg=type(r)) self.assertIsInstance(r, torch.SymInt, msg=type(r))
self.assertExpectedInline( self.assertExpectedInline(
str(shape_env.guards[1][0]), str(shape_env.guards[1][0]),
"""Eq(FloorToInt(3.0*ToFloat(s0)), 15)""", """Eq(FloorToInt(3.0*ToFloat(s97)), 15)""",
) )
def test_sym_trunc(self): def test_sym_trunc(self):
@ -592,14 +592,14 @@ class TestPySymInt(TestCase):
self.assertEqual(r, 2) self.assertEqual(r, 2)
self.assertIsInstance(r, torch.SymInt, msg=type(r)) self.assertIsInstance(r, torch.SymInt, msg=type(r))
self.assertExpectedInline( self.assertExpectedInline(
str(shape_env.guards[0][0]), """Eq(TruncToInt(IntTrueDiv(s0, 2)), 2)""" str(shape_env.guards[0][0]), """Eq(TruncToInt(IntTrueDiv(s97, 2)), 2)"""
) )
r = torch.sym_int(torch.sym_sqrt(a0)) r = torch.sym_int(torch.sym_sqrt(a0))
self.assertEqual(r, 2) self.assertEqual(r, 2)
self.assertIsInstance(r, torch.SymInt, msg=type(r)) self.assertIsInstance(r, torch.SymInt, msg=type(r))
self.assertExpectedInline( self.assertExpectedInline(
str(shape_env.guards[1][0]), str(shape_env.guards[1][0]),
"""Eq(TruncToInt(OpaqueUnaryFn_sqrt(ToFloat(s0))), 2)""", """Eq(TruncToInt(OpaqueUnaryFn_sqrt(ToFloat(s97))), 2)""",
) )
def test_sym_ceil(self): def test_sym_ceil(self):
@ -610,7 +610,7 @@ class TestPySymInt(TestCase):
self.assertIsInstance(r, torch.SymInt, msg=type(r)) self.assertIsInstance(r, torch.SymInt, msg=type(r))
self.assertExpectedInline( self.assertExpectedInline(
str(shape_env.guards[0][0]), str(shape_env.guards[0][0]),
"""Eq(CeilToInt(IntTrueDiv(s0, 2)), 3)""", """Eq(CeilToInt(IntTrueDiv(s97, 2)), 3)""",
) )
r1 = 3.0 * a0 r1 = 3.0 * a0
r = math.floor(r1) r = math.floor(r1)
@ -618,7 +618,7 @@ class TestPySymInt(TestCase):
self.assertIsInstance(r, torch.SymInt, msg=type(r)) self.assertIsInstance(r, torch.SymInt, msg=type(r))
self.assertExpectedInline( self.assertExpectedInline(
str(shape_env.guards[1][0]), str(shape_env.guards[1][0]),
"""Eq(FloorToInt(3.0*ToFloat(s0)), 15)""", """Eq(FloorToInt(3.0*ToFloat(s97)), 15)""",
) )
def test_sym_ite(self): def test_sym_ite(self):
@ -638,7 +638,7 @@ class TestPySymInt(TestCase):
self.assertEqual(type(t), type(r3)) self.assertEqual(type(t), type(r3))
self.assertExpectedInline( self.assertExpectedInline(
str(shape_env.guards[0][0]), str(shape_env.guards[0][0]),
"""Eq(Piecewise((s0, Eq(s0, 5)), (s1, True)), 5)""", """Eq(Piecewise((s97, Eq(s97, 5)), (s26, True)), 5)""",
) )
b4 = f == 5 b4 = f == 5
r4 = torch.sym_ite(b4, t, f) r4 = torch.sym_ite(b4, t, f)
@ -647,7 +647,7 @@ class TestPySymInt(TestCase):
self.assertEqual(type(f), type(r4)) self.assertEqual(type(f), type(r4))
self.assertExpectedInline( self.assertExpectedInline(
str(shape_env.guards[1][0]), str(shape_env.guards[1][0]),
"""Eq(Piecewise((s0, Eq(s1, 5)), (s1, True)), 4)""", """Eq(Piecewise((s97, Eq(s26, 5)), (s26, True)), 4)""",
) )
def test_tracing_sym_ite(self): def test_tracing_sym_ite(self):
@ -679,7 +679,7 @@ def forward(self, x_1):
shape_env = ShapeEnv() shape_env = ShapeEnv()
a0 = create_symint(shape_env, 2) a0 = create_symint(shape_env, 2)
int(a0) int(a0)
self.assertExpectedInline(str(shape_env.guards[0][0]), """Eq(s0, 2)""") self.assertExpectedInline(str(shape_env.guards[0][0]), """Eq(s97, 2)""")
def test_data_dependent_guard(self): def test_data_dependent_guard(self):
shape_env = ShapeEnv() shape_env = ShapeEnv()
@ -710,7 +710,7 @@ def forward(self, x_1):
self.assertTrue(expect_true(i0 < s0)) self.assertTrue(expect_true(i0 < s0))
self.assertExpectedInline( self.assertExpectedInline(
str([ra.expr for ra in shape_env.deferred_runtime_asserts[i0.node.expr]]), str([ra.expr for ra in shape_env.deferred_runtime_asserts[i0.node.expr]]),
"""[u0 < s0]""", """[u0 < s97]""",
) )
self.assertTrue(i0 < s0) self.assertTrue(i0 < s0)
self.assertTrue(i0 != s0) self.assertTrue(i0 != s0)
@ -1173,18 +1173,18 @@ def forward(self, x_1):
out.strip(), out.strip(),
"""\ """\
class f(torch.nn.Module): class f(torch.nn.Module):
def forward(self, a_1: "f32[s0, s1]", b_1: "f32[s2, s1]"): def forward(self, a_1: "f32[s75, s96]", b_1: "f32[s57, s96]"):
# No stacktrace found for following nodes # No stacktrace found for following nodes
sym_size_int: "Sym(s0)" = torch.ops.aten.sym_size.int(a_1, 0) sym_size_int: "Sym(s75)" = torch.ops.aten.sym_size.int(a_1, 0)
sym_size_int_1: "Sym(s2)" = torch.ops.aten.sym_size.int(b_1, 0) sym_size_int_1: "Sym(s57)" = torch.ops.aten.sym_size.int(b_1, 0)
add: "Sym(s0 + s2)" = sym_size_int + sym_size_int_1; sym_size_int = sym_size_int_1 = None add: "Sym(s57 + s75)" = sym_size_int + sym_size_int_1; sym_size_int = sym_size_int_1 = None
sym_size_int_2: "Sym(s1)" = torch.ops.aten.sym_size.int(a_1, 1) sym_size_int_2: "Sym(s96)" = torch.ops.aten.sym_size.int(a_1, 1)
sym_size_int_3: "Sym(s1)" = torch.ops.aten.sym_size.int(b_1, 1); b_1 = None sym_size_int_3: "Sym(s96)" = torch.ops.aten.sym_size.int(b_1, 1); b_1 = None
add_1: "Sym(2*s1)" = sym_size_int_2 + sym_size_int_3; sym_size_int_2 = sym_size_int_3 = None add_1: "Sym(2*s96)" = sym_size_int_2 + sym_size_int_3; sym_size_int_2 = sym_size_int_3 = None
new_empty: "f32[s0 + s2, 2*s1]" = torch.ops.aten.new_empty.default(a_1, [add, add_1], pin_memory = False); a_1 = add = add_1 = None new_empty: "f32[s57 + s75, 2*s96]" = torch.ops.aten.new_empty.default(a_1, [add, add_1], pin_memory = False); a_1 = add = add_1 = None
native_dropout = torch.ops.aten.native_dropout.default(new_empty, 0.5, True); new_empty = None native_dropout = torch.ops.aten.native_dropout.default(new_empty, 0.5, True); new_empty = None
getitem: "f32[s0 + s2, 2*s1]" = native_dropout[0] getitem: "f32[s57 + s75, 2*s96]" = native_dropout[0]
getitem_1: "b8[s0 + s2, 2*s1]" = native_dropout[1]; native_dropout = None getitem_1: "b8[s57 + s75, 2*s96]" = native_dropout[1]; native_dropout = None
return (getitem, getitem_1)""", # noqa: B950 return (getitem, getitem_1)""", # noqa: B950
) )
@ -2846,8 +2846,8 @@ class TestGuardsExpressions(TestCase):
], ],
) )
self.assertEqual(f"{x.stride()}", "(s1, 1)") self.assertEqual(f"{x.stride()}", "(s49, 1)")
self.assertEqual(f"{x.shape}", "torch.Size([s0, s1])") self.assertEqual(f"{x.shape}", "torch.Size([s26, s49])")
x_clean = _remove_symbols_without_guarding(x, 4096) x_clean = _remove_symbols_without_guarding(x, 4096)


@ -1084,7 +1084,7 @@ def forward(self, x_1, y_1):
test_inputs.append([(6, 8)]) test_inputs.append([(6, 8)])
gm = self._test_dynamic(f, [(3, 4)], test_inputs) gm = self._test_dynamic(f, [(3, 4)], test_inputs)
self.assertTrue(eval_guards(gm, torch.randn(4, 5))) self.assertTrue(eval_guards(gm, torch.randn(4, 5)))
self.assertEqual(repr(bind_symbols(gm, torch.randn(4, 5))), "{s0: 4, s1: 5}") self.assertEqual(repr(bind_symbols(gm, torch.randn(4, 5))), "{s75: 4, s96: 5}")
self.assertFalse(eval_guards(gm, torch.randn(25, 5))) self.assertFalse(eval_guards(gm, torch.randn(25, 5)))
self.assertExpectedInline(show_guards(gm), """L['x'].size()[0] <= 19""") self.assertExpectedInline(show_guards(gm), """L['x'].size()[0] <= 19""")
@ -1717,7 +1717,7 @@ def forward(self, a_1):
gm = self._test_dynamic(f, [(1, 6), (8, 1)], test_inputs) gm = self._test_dynamic(f, [(1, 6), (8, 1)], test_inputs)
self.assertTrue(eval_guards(gm, torch.randn(1, 10), torch.randn(6, 1))) self.assertTrue(eval_guards(gm, torch.randn(1, 10), torch.randn(6, 1)))
self.assertFalse(eval_guards(gm, torch.randn(1, 2), torch.randn(4, 1))) self.assertFalse(eval_guards(gm, torch.randn(1, 2), torch.randn(4, 1)))
self.assertExpectedInline(show_guards(gm), """2*L['a'].size()[1]*L['b'].size()[0] > 20""") self.assertExpectedInline(show_guards(gm), """2*L['b'].size()[0]*L['a'].size()[1] > 20""")
def test_new_empty(self): def test_new_empty(self):
def f(a, b): def f(a, b):


@ -869,6 +869,12 @@ def _optimized_add(
if sortkey(lhs._args[0]) > sortkey(rhs._args[-1]): if sortkey(lhs._args[0]) > sortkey(rhs._args[-1]):
return make_optimized(rhs._args + lhs._args) return make_optimized(rhs._args + lhs._args)
# (a1+a3) + (a0+a2) => (a0+a1+a2+a3)
new_args = list(lhs._args)
for a in rhs._args:
new_args = _binary_search_insert_arg(new_args, a)
return make_optimized(new_args)
# (a0+a2) + a1 => (a0+a1+a2) # (a0+a2) + a1 => (a0+a1+a2)
if lhs_is_optimized_summation and rhs.is_symbol: if lhs_is_optimized_summation and rhs.is_symbol:
new_args = _binary_search_insert_arg(list(lhs._args), rhs) new_args = _binary_search_insert_arg(list(lhs._args), rhs)
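The new branch above covers merging two summations that are each already flattened and sorted. The following is a minimal, self-contained sketch of the idea using plain strings in place of sympy symbols; sort_key and merge_sorted_sums are hypothetical stand-ins for the file's sortkey, _binary_search_insert_arg, and make_optimized helpers, not the real implementations.

import bisect


def sort_key(sym: str) -> str:
    # The real code orders sympy symbols by an internal sort key; for this
    # sketch, lexicographic order on the name is enough.
    return sym


def merge_sorted_sums(lhs_args: list[str], rhs_args: list[str]) -> list[str]:
    # Insert each right-hand argument into the (already sorted) left-hand
    # argument list by binary search, so the merged result stays sorted.
    merged = list(lhs_args)
    for arg in rhs_args:
        idx = bisect.bisect_left([sort_key(a) for a in merged], sort_key(arg))
        merged.insert(idx, arg)
    return merged


print(merge_sorted_sums(["s17", "s77"], ["s26", "s96"]))
# ['s17', 's26', 's77', 's96'] -- a single sorted summation, which is what the
# real code then wraps via make_optimized.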


@ -14,6 +14,7 @@ import atexit
import collections import collections
import dis import dis
import functools import functools
import hashlib
import inspect import inspect
import itertools import itertools
import logging import logging
@ -3289,6 +3290,11 @@ class ShapeEnv:
self.guards: list[ShapeGuard] = [] self.guards: list[ShapeGuard] = []
self.axioms: dict[sympy.Expr, sympy.Expr] = {} self.axioms: dict[sympy.Expr, sympy.Expr] = {}
# A set of ids that have already been allocated. When we allocate
# symbol ids from a hash of the source name, we consult this set and
# linearly probe past any id that is already taken, so ids never collide.
self.unique_ids: set[int] = set()
# Maps symbolic ints to their original concrete values # Maps symbolic ints to their original concrete values
# Currently populated from tensors # Currently populated from tensors
self.var_to_val: dict[sympy.Symbol, sympy.Integer] = {} self.var_to_val: dict[sympy.Symbol, sympy.Integer] = {}
@ -4540,13 +4546,14 @@ class ShapeEnv:
# If we're not duck shaping, we always create a new symbol # If we're not duck shaping, we always create a new symbol
# Even if we're duck shaping, if we haven't seen this particular # Even if we're duck shaping, if we haven't seen this particular
# value before, we also create a new symbol # value before, we also create a new symbol
symbol_id = self._generate_unique_id(source.name())
if type(val) is int or is_nested_int(val): if type(val) is int or is_nested_int(val):
sympy_expr = make_symbol( sympy_expr = make_symbol(
SymT.SIZE, len(self.var_to_val), positive=positive, integer=True SymT.SIZE, symbol_id, positive=positive, integer=True
) )
else: else:
sympy_expr = make_symbol( sympy_expr = make_symbol(
SymT.FLOAT, len(self.var_to_val), positive=positive, real=True SymT.FLOAT, symbol_id, positive=positive, real=True
) )
self.source_to_var[source_name] = sympy_expr self.source_to_var[source_name] = sympy_expr
# We always associate vars to vals # We always associate vars to vals
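This swap is what all of the expected-output churn above traces back to: the symbol index used to be the running count len(self.var_to_val), so names like s0, s1 depended purely on allocation order, and extra dimensions that PGO optimistically marked dynamic (and that later specialized away) shifted every subsequent name, defeating the cache. Deriving the index from a hash of the source name keeps a given input's symbol stable no matter what else was allocated. A toy comparison of the two schemes follows; it is not the real ShapeEnv code, and counter_name / hashed_name are illustrative only.

import hashlib


def counter_name(source: str, num_existing: int) -> str:
    # Old scheme: the source is ignored; the name is just the running count.
    return f"s{num_existing}"


def hashed_name(source: str) -> str:
    # New scheme: the name depends only on the source string (collisions are
    # resolved by linear probing in the real code, omitted here).
    return f"s{int(hashlib.sha256(source.encode()).hexdigest(), 16) % 100}"


sources = ["L['x'].size()[0]", "L['y'].size()[0]"]
print([counter_name(s, i) for i, s in enumerate(sources)])      # ['s0', 's1']
# If PGO had also marked some other dim dynamic first, the counter scheme
# shifts every name:
print([counter_name(s, i + 1) for i, s in enumerate(sources)])  # ['s1', 's2']
# The hashed scheme gives the same names in both situations:
print([hashed_name(s) for s in sources])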
@ -6558,6 +6565,13 @@ class ShapeEnv:
sloc, _ = self._get_stack_summary(framework_loc=framework_loc) sloc, _ = self._get_stack_summary(framework_loc=framework_loc)
return sloc return sloc
def _generate_unique_id(self, source_name: str) -> int:
attempt = int(hashlib.sha256(source_name.encode()).hexdigest(), 16) % 100
while attempt in self.unique_ids:
attempt += 1
self.unique_ids.add(attempt)
return attempt
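Because the hash is taken mod 100, collisions are expected once enough symbols exist; the probing loop simply walks forward to the next unused id, and self.unique_ids remembers what has been handed out for the lifetime of the ShapeEnv. A standalone copy of the same logic, with used standing in for self.unique_ids, shows the probing behavior:

import hashlib


def generate_unique_id(source_name: str, used: set[int]) -> int:
    attempt = int(hashlib.sha256(source_name.encode()).hexdigest(), 16) % 100
    while attempt in used:  # linear probing: step forward until a free id
        attempt += 1
    used.add(attempt)
    return attempt


used: set[int] = set()
first = generate_unique_id("L['x'].size()[0]", used)
# Reusing the same source string here is just a way to force a hash collision;
# the colliding allocation is probed to the next free id instead of clashing.
second = generate_unique_id("L['x'].size()[0]", used)
assert second == first + 1

One consequence worth noting: a probed id is only stable if colliding sources are allocated in the same order from run to run, which appears to be the trade-off accepted here in exchange for avoiding wholesale renumbering.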
def _find_frame_locals(self) -> _FrameLocalResult: def _find_frame_locals(self) -> _FrameLocalResult:
""" """
Given the current user code frame, finds the relevant lines of code, Given the current user code frame, finds the relevant lines of code,