From af7719a2fa0d039c995a7fd1268f627c512bc886 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Thu, 27 Mar 2025 16:02:27 +0000 Subject: [PATCH] Revert "Use source hashing to generate consistent symbolic ids (#149665)" This reverts commit 1f92348dc6c60e3020a723b37ecb8226cf2480c0. Reverted https://github.com/pytorch/pytorch/pull/149665 on behalf of https://github.com/malfet due to Broke trunk, see https://hud.pytorch.org/hud/pytorch/pytorch/6eb3c2e2822c50d8a87b43938a9cf7ef0561ede2/1?per_page=50&name_filter=linux-focal-cuda12.6-py3.10-gcc11-sm89&mergeLF=true ([comment](https://github.com/pytorch/pytorch/pull/149665#issuecomment-2758578187)) --- test/dynamo/test_aot_autograd_cache.py | 40 ---- test/dynamo/test_backward_higher_order_ops.py | 4 +- test/dynamo/test_comptime.py | 10 +- test/dynamo/test_exc.py | 64 ++--- test/dynamo/test_export.py | 10 +- test/dynamo/test_functions.py | 36 +-- test/dynamo/test_higher_order_ops.py | 60 ++--- test/dynamo/test_misc.py | 35 ++- test/dynamo/test_repros.py | 2 +- test/dynamo/test_structured_trace.py | 14 +- test/dynamo/test_subclasses.py | 224 +++++++++--------- test/export/test_export.py | 77 +++--- test/export/test_serialize.py | 8 +- test/functorch/test_aotdispatch.py | 4 +- test/functorch/test_control_flow.py | 188 +++++++-------- test/inductor/test_auto_functionalize.py | 74 +++--- test/inductor/test_torchinductor.py | 8 +- test/test_dynamic_shapes.py | 64 ++--- test/test_proxy_tensor.py | 4 +- torch/fx/experimental/sym_node.py | 6 - torch/fx/experimental/symbolic_shapes.py | 18 +- 21 files changed, 437 insertions(+), 513 deletions(-) diff --git a/test/dynamo/test_aot_autograd_cache.py b/test/dynamo/test_aot_autograd_cache.py index b43cbcdf067..d0188a04653 100644 --- a/test/dynamo/test_aot_autograd_cache.py +++ b/test/dynamo/test_aot_autograd_cache.py @@ -196,46 +196,6 @@ class AOTAutogradCacheTests(InductorTestCase): self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 1) self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 1) - @inductor_config.patch("fx_graph_remote_cache", False) - @inductor_config.patch("fx_graph_cache", True) - @functorch_config.patch({"enable_autograd_cache": True}) - def test_symbol_specialization(self): - """ - Verify the symbol specializations don't cause cache miss. - """ - - def fn(x, y, z): - return (torch.randn(5) + x + y, z * torch.randn(1)) - - a = torch.rand(5) - torch._dynamo.maybe_mark_dynamic(a, 0) - b = torch.rand(5) - c = torch.randn(6) - torch._dynamo.maybe_mark_dynamic(c, 0) - - compiled_fn = torch.compile(fn, backend="inductor") - - # A first call should miss in the cache. - compiled_fn(a, b, c) - self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1) - self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 0) - self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 1) - - # A second call should hit even if a new dimension is marked as dynamic - # that is later specialized as part of tracing. - a = torch.rand(5) - torch._dynamo.maybe_mark_dynamic(a, 0) - b = torch.rand(5) - torch._dynamo.maybe_mark_dynamic(b, 0) - c = torch.randn(6) - torch._dynamo.maybe_mark_dynamic(c, 0) - self._clear_dynamo_and_codecache() - - compiled_fn(a, b, c) - self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1) - self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 1) - self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 1) - @functorch_config.patch({"enable_autograd_cache": True}) def test_aot_runtime_trace_joint(self): @torch.compile(backend="inductor") diff --git a/test/dynamo/test_backward_higher_order_ops.py b/test/dynamo/test_backward_higher_order_ops.py index 97577a7130a..5a446883fd7 100644 --- a/test/dynamo/test_backward_higher_order_ops.py +++ b/test/dynamo/test_backward_higher_order_ops.py @@ -245,7 +245,7 @@ class GraphModule(torch.nn.Module): actual, """\ class GraphModule(torch.nn.Module): - def forward(self, L_inputs_ : list, L_sizes_0_: "Sym(2)", L_hooks_1_keywords_fn_keywords_obj_counter: "Sym(s45)"): + def forward(self, L_inputs_ : list, L_sizes_0_: "Sym(2)", L_hooks_1_keywords_fn_keywords_obj_counter: "Sym(s7)"): l_inputs_ = L_inputs_ l_sizes_0_ = L_sizes_0_ l_hooks_1_keywords_fn_keywords_obj_counter = L_hooks_1_keywords_fn_keywords_obj_counter @@ -264,7 +264,7 @@ class GraphModule(torch.nn.Module): copy_: "f32[2]" = new_grad_strided.copy_(aot0_tangents_1); copy_ = None - add: "Sym(s45 + 1)" = l_hooks_1_keywords_fn_keywords_obj_counter + 1; l_hooks_1_keywords_fn_keywords_obj_counter = None + add: "Sym(s7 + 1)" = l_hooks_1_keywords_fn_keywords_obj_counter + 1; l_hooks_1_keywords_fn_keywords_obj_counter = None result: "f32[2]" = aot0_tangents_1 * aot0_tangents_1; aot0_tangents_1 = None diff --git a/test/dynamo/test_comptime.py b/test/dynamo/test_comptime.py index 4edd56c2d9c..1f79c04c624 100644 --- a/test/dynamo/test_comptime.py +++ b/test/dynamo/test_comptime.py @@ -57,18 +57,18 @@ class ComptimeTests(torch._dynamo.test_case.TestCase): self.assertExpectedInline( FILE.getvalue().strip(), """\ -FakeTensor(..., size=(s77,)) +FakeTensor(..., size=(s0,)) 2 -[FakeTensor(..., size=(s77,)), 2] -(FakeTensor(..., size=(s77,)), 2) -{'foo': FakeTensor(..., size=(s77,))} +[FakeTensor(..., size=(s0,)), 2] +(FakeTensor(..., size=(s0,)), 2) +{'foo': FakeTensor(..., size=(s0,))} range(1, 3, 1) Employee(name='foo', id=2) UserDefinedListVariable(mylist) defaultdict(NestedUserFunctionVariable(), {}) set() {'a','b'} -s77""", +s0""", ) def test_print_graph(self): diff --git a/test/dynamo/test_exc.py b/test/dynamo/test_exc.py index acc3fd55f6f..2d0288b8da0 100644 --- a/test/dynamo/test_exc.py +++ b/test/dynamo/test_exc.py @@ -256,34 +256,34 @@ Model: ==> L['x'].size()[0]: 3 ==> L['x'].storage_offset(): 0 ==> L['x'].stride()[0]: 1 + ==> s0: 3 + ==> s1: 0 + ==> s2: 1 ==> s3: 1 - ==> s52: 1 - ==> s77: 3 - ==> s86: 0 Assertions: ==> (== 0 L['x'].storage_offset()) ==> (== 1 L['x'].stride()[0]) - ==> (== L['shape'][0] s86) - ==> (== L['shape'][1] s52) + ==> (== L['shape'][0] s1) + ==> (== L['shape'][1] s2) ==> (== L['shape'][2] s3) - ==> (== L['x'].size()[0] s77) - ==> (> s77 1) + ==> (== L['x'].size()[0] s0) + ==> (> s0 1) Target Expressions: - ==> (!= (+ s3 s52 s86) s77) + ==> (!= (+ s1 s2 s3) s0) + ==> (<= 0 s1) + ==> (<= 0 s2) ==> (<= 0 s3) - ==> (<= 0 s52) - ==> (<= 0 s86) - ==> (<= 2 s77) + ==> (<= 2 s0) ==> (== 0 L['x'].storage_offset()) ==> (== 1 L['x'].stride()[0]) - ==> (== L['shape'][0] s86) - ==> (== L['shape'][1] s52) + ==> (== L['shape'][0] s1) + ==> (== L['shape'][1] s2) ==> (== L['shape'][2] s3) - ==> (== L['x'].size()[0] s77) - ==> (> s77 0) - ==> (>= 0 s86) + ==> (== L['x'].size()[0] s0) + ==> (> s0 0) + ==> (>= 0 s1) Failed Source Expressions: ==> (== (+ L['shape'][0] L['shape'][1] L['shape'][2]) L['x'].size()[0])""", @@ -309,7 +309,7 @@ Failed Source Expressions: BisectValidationException, lambda: fn(torch.randn(20), (5, 10, 5)), """\ -translation validation failed when evaluating: Eq(s3 + s52 + s86, s77) +translation validation failed when evaluating: Eq(s1 + s2 + s3, s0) Failure occurred while running node: %split : [num_users=3] = call_method[target=split](args = (%l_x_, (%l_shape_0_, %l_shape_1_, %l_shape_2_)), kwargs = {}) @@ -321,33 +321,33 @@ Model: ==> L['x'].size()[0]: 3 ==> L['x'].storage_offset(): 0 ==> L['x'].stride()[0]: 1 + ==> s0: 3 + ==> s1: 1 + ==> s2: 1 ==> s3: 0 - ==> s52: 1 - ==> s77: 3 - ==> s86: 1 Assertions: ==> (== 0 L['x'].storage_offset()) ==> (== 1 L['x'].stride()[0]) - ==> (== L['shape'][0] s86) - ==> (== L['shape'][1] s52) + ==> (== L['shape'][0] s1) + ==> (== L['shape'][1] s2) ==> (== L['shape'][2] s3) - ==> (== L['x'].size()[0] s77) - ==> (> s77 1) + ==> (== L['x'].size()[0] s0) + ==> (> s0 1) Target Expressions: - ==> (!= (+ s3 s52 s86) s77) + ==> (!= (+ s1 s2 s3) s0) + ==> (<= 0 s1) + ==> (<= 0 s2) ==> (<= 0 s3) - ==> (<= 0 s52) - ==> (<= 0 s86) - ==> (<= 2 s77) + ==> (<= 2 s0) ==> (== 0 L['x'].storage_offset()) ==> (== 1 L['x'].stride()[0]) - ==> (== L['shape'][0] s86) - ==> (== L['shape'][1] s52) + ==> (== L['shape'][0] s1) + ==> (== L['shape'][1] s2) ==> (== L['shape'][2] s3) - ==> (== L['x'].size()[0] s77) - ==> (> s77 0) + ==> (== L['x'].size()[0] s0) + ==> (> s0 0) Failed Source Expressions: ==> (== (+ L['shape'][0] L['shape'][1] L['shape'][2]) L['x'].size()[0])""", diff --git a/test/dynamo/test_export.py b/test/dynamo/test_export.py index 33ee99e2e04..5080f7f8271 100644 --- a/test/dynamo/test_export.py +++ b/test/dynamo/test_export.py @@ -2703,7 +2703,7 @@ def forward(self, x): for node in ebar.graph_module.graph.nodes if node.op == "placeholder" ], - ["torch.Size([s17, s27, s27])", "torch.Size([s17, s27, s27])"], + ["torch.Size([s0, s1, s1])", "torch.Size([s0, s1, s1])"], ) @torch._dynamo.config.patch( @@ -3480,23 +3480,23 @@ def forward(self, x): true_graph = """\ class GraphModule(torch.nn.Module): def forward(self, pred, x): - arg1: "f32[s77, s27]"; + arg1: "f32[s1, s2]"; arg0, arg1, = fx_pytree.tree_flatten_spec(([pred, x], {}), self._in_spec) l_x_ = arg1 - sin: "f32[s77, s27]" = l_x_.sin(); l_x_ = None + sin: "f32[s1, s2]" = l_x_.sin(); l_x_ = None return pytree.tree_unflatten([sin], self._out_spec) """ false_graph = """\ class GraphModule(torch.nn.Module): def forward(self, pred, x): - arg1: "f32[s77, s27]"; + arg1: "f32[s1, s2]"; arg0, arg1, = fx_pytree.tree_flatten_spec(([pred, x], {}), self._in_spec) l_x_ = arg1 - cos: "f32[s77, s27]" = l_x_.cos(); l_x_ = None + cos: "f32[s1, s2]" = l_x_.cos(); l_x_ = None return pytree.tree_unflatten([cos], self._out_spec) """ true_guard_code = [ diff --git a/test/dynamo/test_functions.py b/test/dynamo/test_functions.py index 09aee481c0c..e8c295e8981 100644 --- a/test/dynamo/test_functions.py +++ b/test/dynamo/test_functions.py @@ -2655,7 +2655,7 @@ class GraphModule(torch.nn.Module): normalize_gm(backend.graphs[0].print_readable(print_output=False)), """\ class GraphModule(torch.nn.Module): - def forward(self, s77: "Sym(s77)", L_x_: "f32[s77]"): + def forward(self, s0: "Sym(s0)", L_x_: "f32[s0]"): l_x_ = L_x_ sum_1: "f32[]" = l_x_.sum(); l_x_ = None @@ -2885,13 +2885,13 @@ class GraphModule(torch.nn.Module): normalize_gm(backend.graphs[0].print_readable(print_output=False)), """\ class GraphModule(torch.nn.Module): - def forward(self, s9: "Sym(s9)", L_lambda0_keywords_y_: "f32[s9, s9]"): + def forward(self, s0: "Sym(s0)", L_lambda0_keywords_y_: "f32[s0, s0]"): l_lambda0_keywords_y_ = L_lambda0_keywords_y_ - mul: "f32[s9, s9]" = l_lambda0_keywords_y_ * l_lambda0_keywords_y_ - mul_1: "f32[s9, s9]" = l_lambda0_keywords_y_ * l_lambda0_keywords_y_; l_lambda0_keywords_y_ = None + mul: "f32[s0, s0]" = l_lambda0_keywords_y_ * l_lambda0_keywords_y_ + mul_1: "f32[s0, s0]" = l_lambda0_keywords_y_ * l_lambda0_keywords_y_; l_lambda0_keywords_y_ = None - mul_2: "f32[s9, s9]" = torch.mul(mul, mul_1); mul = mul_1 = None + mul_2: "f32[s0, s0]" = torch.mul(mul, mul_1); mul = mul_1 = None return (mul_2,) """, ) @@ -2932,14 +2932,14 @@ class GraphModule(torch.nn.Module): normalize_gm(backend.graphs[0].print_readable(print_output=False)), """\ class GraphModule(torch.nn.Module): - def forward(self, s9: "Sym(s9)", L_lambda0_keywords_y_: "f32[s9, s9]"): + def forward(self, s0: "Sym(s0)", L_lambda0_keywords_y_: "f32[s0, s0]"): l_lambda0_keywords_y_ = L_lambda0_keywords_y_ - mul: "f32[s9, s9]" = l_lambda0_keywords_y_ * l_lambda0_keywords_y_ + mul: "f32[s0, s0]" = l_lambda0_keywords_y_ * l_lambda0_keywords_y_ - add: "f32[s9, s9]" = l_lambda0_keywords_y_ + l_lambda0_keywords_y_; l_lambda0_keywords_y_ = None + add: "f32[s0, s0]" = l_lambda0_keywords_y_ + l_lambda0_keywords_y_; l_lambda0_keywords_y_ = None - mul_1: "f32[s9, s9]" = torch.mul(mul, add); mul = add = None + mul_1: "f32[s0, s0]" = torch.mul(mul, add); mul = add = None return (mul_1,) """, ) @@ -2982,14 +2982,14 @@ class GraphModule(torch.nn.Module): normalize_gm(backend.graphs[0].print_readable(print_output=False)), """\ class GraphModule(torch.nn.Module): - def forward(self, s9: "Sym(s9)", L_lambda0_keywords_y_: "f32[s9, s9]"): + def forward(self, s0: "Sym(s0)", L_lambda0_keywords_y_: "f32[s0, s0]"): l_lambda0_keywords_y_ = L_lambda0_keywords_y_ - mul: "f32[s9, s9]" = l_lambda0_keywords_y_ * l_lambda0_keywords_y_ + mul: "f32[s0, s0]" = l_lambda0_keywords_y_ * l_lambda0_keywords_y_ - add: "f32[s9, s9]" = l_lambda0_keywords_y_ + l_lambda0_keywords_y_; l_lambda0_keywords_y_ = None + add: "f32[s0, s0]" = l_lambda0_keywords_y_ + l_lambda0_keywords_y_; l_lambda0_keywords_y_ = None - mul_1: "f32[s9, s9]" = torch.mul(mul, add); mul = add = None + mul_1: "f32[s0, s0]" = torch.mul(mul, add); mul = add = None return (mul_1,) """, ) @@ -3029,14 +3029,14 @@ class GraphModule(torch.nn.Module): normalize_gm(backend.graphs[0].print_readable(print_output=False)), """\ class GraphModule(torch.nn.Module): - def forward(self, s77: "Sym(s77)", L_x_: "f32[s77, s77]"): + def forward(self, s0: "Sym(s0)", L_x_: "f32[s0, s0]"): l_x_ = L_x_ - mul: "f32[s77, s77]" = l_x_ * 4 - mul_1: "f32[s77, s77]" = mul * l_x_; mul = None - mul_2: "f32[s77, s77]" = 20 * l_x_; l_x_ = None + mul: "f32[s0, s0]" = l_x_ * 4 + mul_1: "f32[s0, s0]" = mul * l_x_; mul = None + mul_2: "f32[s0, s0]" = 20 * l_x_; l_x_ = None - mul_3: "f32[s77, s77]" = torch.mul(mul_1, mul_2); mul_1 = mul_2 = None + mul_3: "f32[s0, s0]" = torch.mul(mul_1, mul_2); mul_1 = mul_2 = None return (mul_3,) """, ) diff --git a/test/dynamo/test_higher_order_ops.py b/test/dynamo/test_higher_order_ops.py index dfbad8422b8..a04beddb5c1 100644 --- a/test/dynamo/test_higher_order_ops.py +++ b/test/dynamo/test_higher_order_ops.py @@ -413,18 +413,18 @@ class GraphModule(torch.nn.Module): actual_graph, """\ class GraphModule(torch.nn.Module): - def forward(self, s77: "Sym(s77)", L_x_: "f32[s77, 1]"): + def forward(self, s0: "Sym(s0)", L_x_: "f32[s0, 1]"): l_x_ = L_x_ wrap_body_0 = self.wrap_body_0 - wrap = torch.ops.higher_order.wrap(wrap_body_0, s77, l_x_); wrap_body_0 = s77 = l_x_ = None - getitem: "f32[s77]" = wrap[0]; wrap = None + wrap = torch.ops.higher_order.wrap(wrap_body_0, s0, l_x_); wrap_body_0 = s0 = l_x_ = None + getitem: "f32[s0]" = wrap[0]; wrap = None return (getitem,) class wrap_body_0(torch.nn.Module): - def forward(self, s77: "Sym(s77)", l_x_: "f32[s77, 1]"): - view: "f32[s77]" = l_x_.view(s77); l_x_ = s77 = None - add: "f32[s77]" = view + 0.5; view = None + def forward(self, s0: "Sym(s0)", l_x_: "f32[s0, 1]"): + view: "f32[s0]" = l_x_.view(s0); l_x_ = s0 = None + add: "f32[s0]" = view + 0.5; view = None return (add,) """, ) @@ -606,27 +606,27 @@ class GraphModule(torch.nn.Module): out_graph, """\ class GraphModule(torch.nn.Module): - def forward(self, s77: "Sym(s77)", L_x_: "f32[s77]"): + def forward(self, s0: "Sym(s0)", L_x_: "f32[s0]"): l_x_ = L_x_ sum_1: "f32[]" = l_x_.sum() item: "Sym(zuf0)" = sum_1.item(); sum_1 = None wrap_body_1 = self.wrap_body_1 - wrap = torch.ops.higher_order.wrap(wrap_body_1, s77, l_x_, item); wrap_body_1 = s77 = l_x_ = item = None - getitem: "f32[s77]" = wrap[0]; wrap = None + wrap = torch.ops.higher_order.wrap(wrap_body_1, s0, l_x_, item); wrap_body_1 = s0 = l_x_ = item = None + getitem: "f32[s0]" = wrap[0]; wrap = None return (getitem,) class wrap_body_1(torch.nn.Module): - def forward(self, s77: "Sym(s77)", l_x_: "f32[s77]", item: "Sym(zuf0)"): + def forward(self, s0: "Sym(s0)", l_x_: "f32[s0]", item: "Sym(zuf0)"): wrap_body_0 = self.wrap_body_0 - wrap = torch.ops.higher_order.wrap(wrap_body_0, s77, l_x_, item); wrap_body_0 = s77 = l_x_ = item = None - getitem: "f32[s77]" = wrap[0]; wrap = None + wrap = torch.ops.higher_order.wrap(wrap_body_0, s0, l_x_, item); wrap_body_0 = s0 = l_x_ = item = None + getitem: "f32[s0]" = wrap[0]; wrap = None return (getitem,) class wrap_body_0(torch.nn.Module): - def forward(self, s77: "Sym(s77)", l_x_: "f32[s77]", item: "Sym(zuf0)"): - add: "f32[s77]" = l_x_ + item; l_x_ = item = None + def forward(self, s0: "Sym(s0)", l_x_: "f32[s0]", item: "Sym(zuf0)"): + add: "f32[s0]" = l_x_ + item; l_x_ = item = None return (add,) """, ) @@ -692,7 +692,7 @@ class GraphModule(torch.nn.Module): out_graph, """\ class GraphModule(torch.nn.Module): - def forward(self, s77: "Sym(s77)", L_x_: "f32[s77]"): + def forward(self, s0: "Sym(s0)", L_x_: "f32[s0]"): l_x_ = L_x_ c: "i64[u0, 1]" = l_x_.nonzero() @@ -704,22 +704,22 @@ class GraphModule(torch.nn.Module): _assert_scalar_default = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge = _assert_scalar_default = None wrap_body_1 = self.wrap_body_1 - wrap = torch.ops.higher_order.wrap(wrap_body_1, s77, l_x_, sym_size_int_1, c); wrap_body_1 = s77 = l_x_ = sym_size_int_1 = c = None - getitem: "f32[s77]" = wrap[0] + wrap = torch.ops.higher_order.wrap(wrap_body_1, s0, l_x_, sym_size_int_1, c); wrap_body_1 = s0 = l_x_ = sym_size_int_1 = c = None + getitem: "f32[s0]" = wrap[0] getitem_1: "f32[u0, 1]" = wrap[1]; wrap = None return (getitem, getitem_1) class wrap_body_1(torch.nn.Module): - def forward(self, s77: "Sym(s77)", l_x_: "f32[s77]", u0: "Sym(u0)", c: "i64[u0, 1]"): + def forward(self, s0: "Sym(s0)", l_x_: "f32[s0]", u0: "Sym(u0)", c: "i64[u0, 1]"): wrap_body_0 = self.wrap_body_0 - wrap = torch.ops.higher_order.wrap(wrap_body_0, s77, l_x_, u0, c); wrap_body_0 = s77 = l_x_ = u0 = c = None - child: "f32[s77]" = wrap[0] + wrap = torch.ops.higher_order.wrap(wrap_body_0, s0, l_x_, u0, c); wrap_body_0 = s0 = l_x_ = u0 = c = None + child: "f32[s0]" = wrap[0] child_1: "f32[u0, 1]" = wrap[1]; wrap = None return (child, child_1) class wrap_body_0(torch.nn.Module): - def forward(self, s77: "Sym(s77)", l_x_: "f32[s77]", u0: "Sym(u0)", c: "i64[u0, 1]"): - child: "f32[s77]" = l_x_.sin(); l_x_ = None + def forward(self, s0: "Sym(s0)", l_x_: "f32[s0]", u0: "Sym(u0)", c: "i64[u0, 1]"): + child: "f32[s0]" = l_x_.sin(); l_x_ = None child_1: "f32[u0, 1]" = c.sin(); c = None return (child, child_1) """, @@ -994,25 +994,25 @@ class GraphModule(torch.nn.Module): out_graph, """\ class GraphModule(torch.nn.Module): - def forward(self, s77: "Sym(s77)", s27: "Sym(s27)", L_x_: "f32[s77, s27]", s94: "Sym(s94)", L_y_: "f32[s27, s94]"): + def forward(self, s0: "Sym(s0)", s1: "Sym(s1)", L_x_: "f32[s0, s1]", s2: "Sym(s2)", L_y_: "f32[s1, s2]"): l_x_ = L_x_ l_y_ = L_y_ wrap_body_1 = self.wrap_body_1 - wrap = torch.ops.higher_order.wrap(wrap_body_1, s77, s27, l_x_, s94, l_y_); wrap_body_1 = s77 = s27 = l_x_ = s94 = l_y_ = None - getitem: "f32[s77, s94]" = wrap[0]; wrap = None + wrap = torch.ops.higher_order.wrap(wrap_body_1, s0, s1, l_x_, s2, l_y_); wrap_body_1 = s0 = s1 = l_x_ = s2 = l_y_ = None + getitem: "f32[s0, s2]" = wrap[0]; wrap = None return (getitem,) class wrap_body_1(torch.nn.Module): - def forward(self, s77: "Sym(s77)", s27: "Sym(s27)", l_x_: "f32[s77, s27]", s94: "Sym(s94)", l_y_: "f32[s27, s94]"): + def forward(self, s0: "Sym(s0)", s1: "Sym(s1)", l_x_: "f32[s0, s1]", s2: "Sym(s2)", l_y_: "f32[s1, s2]"): wrap_body_0 = self.wrap_body_0 - wrap = torch.ops.higher_order.wrap(wrap_body_0, s77, s27, l_x_, s94, l_y_); wrap_body_0 = s77 = s27 = l_x_ = s94 = l_y_ = None - getitem: "f32[s77, s94]" = wrap[0]; wrap = None + wrap = torch.ops.higher_order.wrap(wrap_body_0, s0, s1, l_x_, s2, l_y_); wrap_body_0 = s0 = s1 = l_x_ = s2 = l_y_ = None + getitem: "f32[s0, s2]" = wrap[0]; wrap = None return (getitem,) class wrap_body_0(torch.nn.Module): - def forward(self, s77: "Sym(s77)", s27: "Sym(s27)", l_x_: "f32[s77, s27]", s94: "Sym(s94)", l_y_: "f32[s27, s94]"): - matmul: "f32[s77, s94]" = l_x_ @ l_y_; l_x_ = l_y_ = None + def forward(self, s0: "Sym(s0)", s1: "Sym(s1)", l_x_: "f32[s0, s1]", s2: "Sym(s2)", l_y_: "f32[s1, s2]"): + matmul: "f32[s0, s2]" = l_x_ @ l_y_; l_x_ = l_y_ = None return (matmul,) """, ) diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py index a6b3a29eb4d..c74d6039dd4 100644 --- a/test/dynamo/test_misc.py +++ b/test/dynamo/test_misc.py @@ -10382,22 +10382,19 @@ ShapeEnv not equal: field values don't match: > Left: {x.size()[0]: x.size()[0], x.size()[1]: x.size()[1], x.storage_offset(): x.storage_offset(), x.stride()[0]: x.stride()[0], x.stride()[1]: x.stride()[1]} > Right: {} ==> source_to_var: values don't match. - > Left: {x.size()[0]: s93, x.size()[1]: s44} - > Right: {} -==> unique_ids: values don't match. - > Left: {44, 93} + > Left: {x.size()[0]: s0, x.size()[1]: s1} > Right: {} ==> val_to_var: values don't match. - > Left: {0: 0, 1: 1, 2: s44, 3: s93} + > Left: {0: 0, 1: 1, 2: s1, 3: s0} > Right: {0: 0, 1: 1} ==> var_to_range: values don't match. - > Left: {s44: VR[2, int_oo], s93: VR[2, int_oo]} + > Left: {s0: VR[2, int_oo], s1: VR[2, int_oo]} > Right: {} ==> var_to_sources: values don't match. - > Left: {s44: [TensorPropertySource(base=ConstantSource(source_name='x'), prop=, idx=1)], s93: [TensorPropertySource(base=ConstantSource(source_name='x'), prop=, idx=0)]} + > Left: {s0: [TensorPropertySource(base=ConstantSource(source_name='x'), prop=, idx=0)], s1: [TensorPropertySource(base=ConstantSource(source_name='x'), prop=, idx=1)]} > Right: {} ==> var_to_val: values don't match. - > Left: {s44: 2, s93: 3} + > Left: {s0: 3, s1: 2} > Right: {} """, ) @@ -10456,13 +10453,13 @@ ShapeEnv not equal: field values don't match: ShapeEnv not equal: field values don't match: ==> axioms: values don't match. - > Left: {(Mod(s93, 3)) < 0: False, (Mod(s93, 3)) <= 0: True, 0 < (Mod(s93, 3)): False, 0 <= (Mod(s93, 3)): True, Eq(0, Mod(s93, 3)): True, Eq(Mod(s93, 3), 0): True, Ne(0, Mod(s93, 3)): False, Ne(Mod(s93, 3), 0): False} + > Left: {(Mod(s0, 3)) < 0: False, (Mod(s0, 3)) <= 0: True, 0 < (Mod(s0, 3)): False, 0 <= (Mod(s0, 3)): True, Eq(0, Mod(s0, 3)): True, Eq(Mod(s0, 3), 0): True, Ne(0, Mod(s0, 3)): False, Ne(Mod(s0, 3), 0): False} > Right: {} ==> divisible: values don't match. - > Left: {Mod(s93, 3)} + > Left: {Mod(s0, 3)} > Right: {} ==> guards: values don't match. - > Left: [Eq(Mod(s93, 3), 0)] + > Left: [Eq(Mod(s0, 3), 0)] > Right: [] ==> name_to_node: values don't match. > Left: {_assert, eq, mod, x_size_0_, x_size_1_, x_storage_offset, x_stride_0_, x_stride_1_} @@ -10499,17 +10496,17 @@ ShapeEnv not equal: field values don't match: > Left: {False: False, True: True} > Right: {} ==> guards: values don't match. - > Left: [Eq(s93, 3)] + > Left: [Eq(s0, 3)] > Right: [] ==> name_to_node: values don't match. > Left: {_assert, eq, x_size_0_, x_size_1_, x_storage_offset, x_stride_0_, x_stride_1_} > Right: {x_size_0_, x_size_1_, x_storage_offset, x_stride_0_, x_stride_1_} ==> replacements: values don't match. - > Left: {s93: 3} + > Left: {s0: 3} > Right: {} ==> var_to_range: values don't match. - > Left: {s44: VR[2, int_oo], s93: VR[3, 3]} - > Right: {s44: VR[2, int_oo], s93: VR[2, int_oo]} + > Left: {s0: VR[3, 3], s1: VR[2, int_oo]} + > Right: {s0: VR[2, int_oo], s1: VR[2, int_oo]} """, ) self._replay_and_check(main) @@ -10540,17 +10537,17 @@ ShapeEnv not equal: field values don't match: ShapeEnv not equal: field values don't match: ==> axioms: values don't match. - > Left: {3 <= s93: True, s93 < 3: False} + > Left: {3 <= s0: True, s0 < 3: False} > Right: {} ==> guards: values don't match. - > Left: [s93 >= 3] + > Left: [s0 >= 3] > Right: [] ==> name_to_node: values don't match. > Left: {_assert, ge, x_size_0_, x_size_1_, x_storage_offset, x_stride_0_, x_stride_1_} > Right: {x_size_0_, x_size_1_, x_storage_offset, x_stride_0_, x_stride_1_} ==> var_to_range: values don't match. - > Left: {s44: VR[2, int_oo], s93: VR[3, int_oo]} - > Right: {s44: VR[2, int_oo], s93: VR[2, int_oo]} + > Left: {s0: VR[3, int_oo], s1: VR[2, int_oo]} + > Right: {s0: VR[2, int_oo], s1: VR[2, int_oo]} """, ) self._replay_and_check(main) diff --git a/test/dynamo/test_repros.py b/test/dynamo/test_repros.py index 6d8c86923cf..6bbe9127212 100644 --- a/test/dynamo/test_repros.py +++ b/test/dynamo/test_repros.py @@ -4768,7 +4768,7 @@ class ReproTests(torch._dynamo.test_case.TestCase): self.assertExpectedInline( str(graph.code).strip(), """\ -def forward(self, s77 : torch.SymInt, s27 : torch.SymInt, L_x_ : torch.Tensor): +def forward(self, s0 : torch.SymInt, s1 : torch.SymInt, L_x_ : torch.Tensor): l_x_ = L_x_ getitem_2 = l_x_[0] sum_1 = getitem_2.sum(); getitem_2 = None diff --git a/test/dynamo/test_structured_trace.py b/test/dynamo/test_structured_trace.py index a1004885161..860c5bddaba 100644 --- a/test/dynamo/test_structured_trace.py +++ b/test/dynamo/test_structured_trace.py @@ -339,7 +339,7 @@ class StructuredTraceTest(TestCase): {"describe_storage": {"id": 0, "describer_id": "ID", "size": 4000000}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0} {"describe_tensor": {"id": 0, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cpu')", "size": [1000, 1000], "is_leaf": true, "stride": [1000, 1], "storage": 0, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0} {"describe_source": {"describer_id": "ID", "id": 0, "source": "L['x']"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0} -{"create_symbol": {"symbol": "s48", "val": "1", "vr": "[-int_oo, int_oo]", "source": "L['y']", "user_stack": "STACK", "stack": "STACK"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0} +{"create_symbol": {"symbol": "s0", "val": "1", "vr": "[-int_oo, int_oo]", "source": "L['y']", "user_stack": "STACK", "stack": "STACK"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0} {"dynamo_output_graph": {"sizes": {"l_x_": [1000, 1000], "add": [1000, 1000]}}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"} {"inductor_pre_grad_graph": {}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "before_recompile_pre_grad", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"} @@ -767,15 +767,15 @@ class StructuredTraceTest(TestCase): {"describe_storage": {"id": 0, "describer_id": "ID", "size": 200}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0} {"describe_tensor": {"id": 0, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cpu')", "size": [5, 10], "is_leaf": true, "stride": [10, 1], "storage": 0, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0} {"describe_source": {"describer_id": "ID", "id": 0, "source": "L['a']"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0} -{"create_symbol": {"symbol": "s97", "val": "5", "vr": "[2, int_oo]", "source": "L['a'].size()[0]", "user_stack": "STACK", "stack": "STACK"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0} -{"create_symbol": {"symbol": "s98", "val": "10", "vr": "[2, int_oo]", "source": "L['a'].size()[1]", "user_stack": "STACK", "stack": "STACK"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0} +{"create_symbol": {"symbol": "s0", "val": "5", "vr": "[2, int_oo]", "source": "L['a'].size()[0]", "user_stack": "STACK", "stack": "STACK"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0} +{"create_symbol": {"symbol": "s1", "val": "10", "vr": "[2, int_oo]", "source": "L['a'].size()[1]", "user_stack": "STACK", "stack": "STACK"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0} {"describe_storage": {"id": 1, "describer_id": "ID", "size": 600}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0} {"describe_tensor": {"id": 1, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cpu')", "size": [10, 15], "is_leaf": true, "stride": [15, 1], "storage": 1, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0} {"describe_source": {"describer_id": "ID", "id": 1, "source": "L['b']"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0} -{"create_symbol": {"symbol": "s52", "val": "10", "vr": "[2, int_oo]", "source": "L['b'].size()[0]", "user_stack": "STACK", "stack": "STACK"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0} -{"create_symbol": {"symbol": "s20", "val": "15", "vr": "[2, int_oo]", "source": "L['b'].size()[1]", "user_stack": "STACK", "stack": "STACK"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0} -{"guard_added_fast": {"expr": "Eq(s98, s52)", "user_stack": "STACK", "stack": "STACK"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0} -{"dynamo_output_graph": {"sizes": {"l_a_": ["s97", "s52"], "l_b_": ["s52", "s20"], "matmul": ["s97", "s20"]}}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"} +{"create_symbol": {"symbol": "s2", "val": "10", "vr": "[2, int_oo]", "source": "L['b'].size()[0]", "user_stack": "STACK", "stack": "STACK"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0} +{"create_symbol": {"symbol": "s3", "val": "15", "vr": "[2, int_oo]", "source": "L['b'].size()[1]", "user_stack": "STACK", "stack": "STACK"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0} +{"guard_added_fast": {"expr": "Eq(s1, s2)", "user_stack": "STACK", "stack": "STACK"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0} +{"dynamo_output_graph": {"sizes": {"l_a_": ["s0", "s1"], "l_b_": ["s1", "s3"], "matmul": ["s0", "s3"]}}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"} {"dynamo_cpp_guards_str": {}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"} {"compilation_metrics": "METRICS", "frame_id": 0, "frame_compile_id": 1, "attempt": 0} """, # noqa: B950 diff --git a/test/dynamo/test_subclasses.py b/test/dynamo/test_subclasses.py index ef2acadac89..d4ef79b0723 100644 --- a/test/dynamo/test_subclasses.py +++ b/test/dynamo/test_subclasses.py @@ -1373,21 +1373,21 @@ class GraphModule(torch.nn.Module): # During fakeifying, we end up allocating a separate symint # for the outer and inner tensor (in this test, s0 is unused). expected_var_to_val = { - "s50": 4, - "s77": 8, + "s0": 8, + "s1": 4, } expected_var_to_sources = { - "s50": "L['x'].inner_elem.size()[0]", - "s77": "L['x'].size()[0]", + "s0": "L['x'].size()[0]", + "s1": "L['x'].inner_elem.size()[0]", } self.assertEqual(curr_var_to_val, expected_var_to_val) self.assertEqual(curr_var_to_sources, expected_var_to_sources) self.assertExpectedInline( "\n".join(guards), """\ -Eq(2*s50, s77) -2*s50 < 13 -s50 > 3""", +Eq(2*s1, s0) +2*s1 < 13 +s1 > 3""", ) def test_wrapper_subclass_with_same_sized_inner_tensor(self): @@ -1976,9 +1976,9 @@ class GraphModule(torch.nn.Module): normalize_gm(fw[0].print_readable(print_output=False)), """\ class GraphModule(torch.nn.Module): - def forward(self, primals_1: "Sym(s47)", primals_2: "Sym(s16)", primals_3: "f32[s47, s16]", primals_4: "f32[s47, s16]", primals_5: "Sym(s47)", primals_6: "Sym(s16)", primals_7: "Sym(s16)"): - mul: "f32[s47, s16]" = torch.ops.aten.mul.Tensor(primals_3, primals_1); primals_3 = None - mul_3: "f32[s47, s16]" = torch.ops.aten.mul.Tensor(primals_4, primals_1); primals_4 = None + def forward(self, primals_1: "Sym(s0)", primals_2: "Sym(s1)", primals_3: "f32[s0, s1]", primals_4: "f32[s0, s1]", primals_5: "Sym(s0)", primals_6: "Sym(s1)", primals_7: "Sym(s1)"): + mul: "f32[s0, s1]" = torch.ops.aten.mul.Tensor(primals_3, primals_1); primals_3 = None + mul_3: "f32[s0, s1]" = torch.ops.aten.mul.Tensor(primals_4, primals_1); primals_4 = None return (mul, mul_3, primals_5, primals_7, primals_7, primals_1, primals_5, primals_7) """, # noqa: B950 ) @@ -1987,9 +1987,9 @@ class GraphModule(torch.nn.Module): normalize_gm(bw[0].print_readable(print_output=False)), """\ class GraphModule(torch.nn.Module): - def forward(self, primals_1: "Sym(s47)", primals_5: "Sym(s47)", primals_7: "Sym(s16)", tangents_1: "f32[s47, s16]", tangents_2: "f32[s47, s16]"): - mul_8: "f32[s47, s16]" = torch.ops.aten.mul.Tensor(tangents_1, primals_1); tangents_1 = None - mul_9: "f32[s47, s16]" = torch.ops.aten.mul.Tensor(tangents_2, primals_1); tangents_2 = primals_1 = None + def forward(self, primals_1: "Sym(s0)", primals_5: "Sym(s0)", primals_7: "Sym(s1)", tangents_1: "f32[s0, s1]", tangents_2: "f32[s0, s1]"): + mul_8: "f32[s0, s1]" = torch.ops.aten.mul.Tensor(tangents_1, primals_1); tangents_1 = None + mul_9: "f32[s0, s1]" = torch.ops.aten.mul.Tensor(tangents_2, primals_1); tangents_2 = primals_1 = None return (None, None, mul_8, mul_9, primals_5, primals_7, primals_7) """, # noqa: B950 ) @@ -2009,12 +2009,12 @@ class GraphModule(torch.nn.Module): normalize_gm(fw[0].print_readable(print_output=False)), """\ class GraphModule(torch.nn.Module): - def forward(self, primals_1: "Sym(s47)", primals_2: "Sym(s16)", primals_3: "f32[s47, s16]", primals_4: "f32[s47, s16]", primals_5: "Sym(s47)", primals_6: "Sym(s16)", primals_7: "Sym(s16)"): - clone: "f32[s47, s16]" = torch.ops.aten.clone.default(primals_3); primals_3 = None - clone_1: "f32[s47, s16]" = torch.ops.aten.clone.default(primals_4); primals_4 = None + def forward(self, primals_1: "Sym(s0)", primals_2: "Sym(s1)", primals_3: "f32[s0, s1]", primals_4: "f32[s0, s1]", primals_5: "Sym(s0)", primals_6: "Sym(s1)", primals_7: "Sym(s1)"): + clone: "f32[s0, s1]" = torch.ops.aten.clone.default(primals_3); primals_3 = None + clone_1: "f32[s0, s1]" = torch.ops.aten.clone.default(primals_4); primals_4 = None - view: "f32[s16, s47]" = torch.ops.aten.view.default(clone, [primals_2, primals_1]); clone = None - view_1: "f32[s16, s47]" = torch.ops.aten.view.default(clone_1, [primals_2, primals_1]); clone_1 = primals_1 = None + view: "f32[s1, s0]" = torch.ops.aten.view.default(clone, [primals_2, primals_1]); clone = None + view_1: "f32[s1, s0]" = torch.ops.aten.view.default(clone_1, [primals_2, primals_1]); clone_1 = primals_1 = None return (view, view_1, primals_2, primals_5, primals_5, primals_5, primals_7) """, # noqa: B950 ) @@ -2023,9 +2023,9 @@ class GraphModule(torch.nn.Module): normalize_gm(bw[0].print_readable(print_output=False)), """\ class GraphModule(torch.nn.Module): - def forward(self, primals_5: "Sym(s47)", primals_7: "Sym(s16)", tangents_1: "f32[s16, s47]", tangents_2: "f32[s16, s47]"): - view_2: "f32[s47, s16]" = torch.ops.aten.view.default(tangents_1, [primals_5, primals_7]); tangents_1 = None - view_3: "f32[s47, s16]" = torch.ops.aten.view.default(tangents_2, [primals_5, primals_7]); tangents_2 = None + def forward(self, primals_5: "Sym(s0)", primals_7: "Sym(s1)", tangents_1: "f32[s1, s0]", tangents_2: "f32[s1, s0]"): + view_2: "f32[s0, s1]" = torch.ops.aten.view.default(tangents_1, [primals_5, primals_7]); tangents_1 = None + view_3: "f32[s0, s1]" = torch.ops.aten.view.default(tangents_2, [primals_5, primals_7]); tangents_2 = None return (None, None, view_2, view_3, primals_5, primals_7, primals_7) """, # noqa: B950 ) @@ -2047,15 +2047,15 @@ class GraphModule(torch.nn.Module): normalize_gm(fw[0].print_readable(print_output=False)), """\ class GraphModule(torch.nn.Module): - def forward(self, primals_1: "Sym(s97)", primals_2: "Sym(s98)", primals_3: "f32[s97, s98]", primals_4: "f32[s97, s98]", primals_5: "Sym(s97)", primals_6: "Sym(s98)", primals_7: "Sym(s98)"): - mul: "f32[s97, s98]" = torch.ops.aten.mul.Tensor(primals_3, primals_1); primals_3 = None - mul_3: "f32[s97, s98]" = torch.ops.aten.mul.Tensor(primals_4, primals_1); primals_4 = None - mul_8: "f32[s97, s98]" = torch.ops.aten.mul.Tensor(mul, primals_2); mul = None - mul_11: "f32[s97, s98]" = torch.ops.aten.mul.Tensor(mul_3, primals_2); mul_3 = None - mul_16: "f32[s97, s98]" = torch.ops.aten.mul.Tensor(mul_8, primals_1); mul_8 = None - mul_19: "f32[s97, s98]" = torch.ops.aten.mul.Tensor(mul_11, primals_1); mul_11 = None - mul_24: "f32[s97, s98]" = torch.ops.aten.mul.Tensor(mul_16, primals_2); mul_16 = None - mul_27: "f32[s97, s98]" = torch.ops.aten.mul.Tensor(mul_19, primals_2); mul_19 = None + def forward(self, primals_1: "Sym(s0)", primals_2: "Sym(s1)", primals_3: "f32[s0, s1]", primals_4: "f32[s0, s1]", primals_5: "Sym(s0)", primals_6: "Sym(s1)", primals_7: "Sym(s1)"): + mul: "f32[s0, s1]" = torch.ops.aten.mul.Tensor(primals_3, primals_1); primals_3 = None + mul_3: "f32[s0, s1]" = torch.ops.aten.mul.Tensor(primals_4, primals_1); primals_4 = None + mul_8: "f32[s0, s1]" = torch.ops.aten.mul.Tensor(mul, primals_2); mul = None + mul_11: "f32[s0, s1]" = torch.ops.aten.mul.Tensor(mul_3, primals_2); mul_3 = None + mul_16: "f32[s0, s1]" = torch.ops.aten.mul.Tensor(mul_8, primals_1); mul_8 = None + mul_19: "f32[s0, s1]" = torch.ops.aten.mul.Tensor(mul_11, primals_1); mul_11 = None + mul_24: "f32[s0, s1]" = torch.ops.aten.mul.Tensor(mul_16, primals_2); mul_16 = None + mul_27: "f32[s0, s1]" = torch.ops.aten.mul.Tensor(mul_19, primals_2); mul_19 = None return (mul_24, mul_27, primals_5, primals_7, primals_7, primals_1, primals_2, primals_5, primals_7) """, # noqa: B950 ) @@ -2064,15 +2064,15 @@ class GraphModule(torch.nn.Module): normalize_gm(bw[0].print_readable(print_output=False)), """\ class GraphModule(torch.nn.Module): - def forward(self, primals_1: "Sym(s97)", primals_2: "Sym(s98)", primals_5: "Sym(s97)", primals_7: "Sym(s98)", tangents_1: "f32[s97, s98]", tangents_2: "f32[s97, s98]"): - mul_32: "f32[s97, s98]" = torch.ops.aten.mul.Tensor(tangents_1, primals_2); tangents_1 = None - mul_33: "f32[s97, s98]" = torch.ops.aten.mul.Tensor(tangents_2, primals_2); tangents_2 = None - mul_34: "f32[s97, s98]" = torch.ops.aten.mul.Tensor(mul_32, primals_1); mul_32 = None - mul_35: "f32[s97, s98]" = torch.ops.aten.mul.Tensor(mul_33, primals_1); mul_33 = None - mul_36: "f32[s97, s98]" = torch.ops.aten.mul.Tensor(mul_34, primals_2); mul_34 = None - mul_37: "f32[s97, s98]" = torch.ops.aten.mul.Tensor(mul_35, primals_2); mul_35 = primals_2 = None - mul_38: "f32[s97, s98]" = torch.ops.aten.mul.Tensor(mul_36, primals_1); mul_36 = None - mul_39: "f32[s97, s98]" = torch.ops.aten.mul.Tensor(mul_37, primals_1); mul_37 = primals_1 = None + def forward(self, primals_1: "Sym(s0)", primals_2: "Sym(s1)", primals_5: "Sym(s0)", primals_7: "Sym(s1)", tangents_1: "f32[s0, s1]", tangents_2: "f32[s0, s1]"): + mul_32: "f32[s0, s1]" = torch.ops.aten.mul.Tensor(tangents_1, primals_2); tangents_1 = None + mul_33: "f32[s0, s1]" = torch.ops.aten.mul.Tensor(tangents_2, primals_2); tangents_2 = None + mul_34: "f32[s0, s1]" = torch.ops.aten.mul.Tensor(mul_32, primals_1); mul_32 = None + mul_35: "f32[s0, s1]" = torch.ops.aten.mul.Tensor(mul_33, primals_1); mul_33 = None + mul_36: "f32[s0, s1]" = torch.ops.aten.mul.Tensor(mul_34, primals_2); mul_34 = None + mul_37: "f32[s0, s1]" = torch.ops.aten.mul.Tensor(mul_35, primals_2); mul_35 = primals_2 = None + mul_38: "f32[s0, s1]" = torch.ops.aten.mul.Tensor(mul_36, primals_1); mul_36 = None + mul_39: "f32[s0, s1]" = torch.ops.aten.mul.Tensor(mul_37, primals_1); mul_37 = primals_1 = None return (None, None, mul_38, mul_39, primals_5, primals_7, primals_7) """, # noqa: B950 ) @@ -2092,12 +2092,12 @@ class GraphModule(torch.nn.Module): normalize_gm(fw[0].print_readable(print_output=False)), """\ class GraphModule(torch.nn.Module): - def forward(self, primals_1: "Sym(s47)", primals_2: "Sym(s16)", primals_3: "f32[s47, s16]", primals_4: "f32[s47, s16]", primals_5: "Sym(s47)", primals_6: "Sym(s16)", primals_7: "Sym(s16)"): - clone: "f32[s47, s16]" = torch.ops.aten.clone.default(primals_3); primals_3 = None - clone_1: "f32[s47, s16]" = torch.ops.aten.clone.default(primals_4); primals_4 = None + def forward(self, primals_1: "Sym(s0)", primals_2: "Sym(s1)", primals_3: "f32[s0, s1]", primals_4: "f32[s0, s1]", primals_5: "Sym(s0)", primals_6: "Sym(s1)", primals_7: "Sym(s1)"): + clone: "f32[s0, s1]" = torch.ops.aten.clone.default(primals_3); primals_3 = None + clone_1: "f32[s0, s1]" = torch.ops.aten.clone.default(primals_4); primals_4 = None - view: "f32[s47, s16]" = torch.ops.aten.view.default(clone, [primals_1, primals_2]); clone = None - view_1: "f32[s47, s16]" = torch.ops.aten.view.default(clone_1, [primals_1, primals_2]); clone_1 = primals_1 = primals_2 = None + view: "f32[s0, s1]" = torch.ops.aten.view.default(clone, [primals_1, primals_2]); clone = None + view_1: "f32[s0, s1]" = torch.ops.aten.view.default(clone_1, [primals_1, primals_2]); clone_1 = primals_1 = primals_2 = None return (view, view_1, primals_5, primals_7, primals_7, primals_5, primals_7) """, # noqa: B950 ) @@ -2106,9 +2106,9 @@ class GraphModule(torch.nn.Module): normalize_gm(bw[0].print_readable(print_output=False)), """\ class GraphModule(torch.nn.Module): - def forward(self, primals_5: "Sym(s47)", primals_7: "Sym(s16)", tangents_1: "f32[s47, s16]", tangents_2: "f32[s47, s16]"): - view_2: "f32[s47, s16]" = torch.ops.aten.view.default(tangents_1, [primals_5, primals_7]); tangents_1 = None - view_3: "f32[s47, s16]" = torch.ops.aten.view.default(tangents_2, [primals_5, primals_7]); tangents_2 = None + def forward(self, primals_5: "Sym(s0)", primals_7: "Sym(s1)", tangents_1: "f32[s0, s1]", tangents_2: "f32[s0, s1]"): + view_2: "f32[s0, s1]" = torch.ops.aten.view.default(tangents_1, [primals_5, primals_7]); tangents_1 = None + view_3: "f32[s0, s1]" = torch.ops.aten.view.default(tangents_2, [primals_5, primals_7]); tangents_2 = None return (None, None, view_2, view_3, primals_5, primals_7, primals_7) """, # noqa: B950 ) @@ -2128,13 +2128,13 @@ class GraphModule(torch.nn.Module): normalize_gm(fw[0].print_readable(print_output=False)), """\ class GraphModule(torch.nn.Module): - def forward(self, primals_1: "Sym(s47)", primals_2: "Sym(s16)", primals_3: "f32[s47, s16]", primals_4: "f32[s47, s16]", primals_5: "Sym(s47)", primals_6: "Sym(s16)", primals_7: "Sym(s16)"): - clone: "f32[s47, s16]" = torch.ops.aten.clone.default(primals_3); primals_3 = None - clone_1: "f32[s47, s16]" = torch.ops.aten.clone.default(primals_4); primals_4 = None + def forward(self, primals_1: "Sym(s0)", primals_2: "Sym(s1)", primals_3: "f32[s0, s1]", primals_4: "f32[s0, s1]", primals_5: "Sym(s0)", primals_6: "Sym(s1)", primals_7: "Sym(s1)"): + clone: "f32[s0, s1]" = torch.ops.aten.clone.default(primals_3); primals_3 = None + clone_1: "f32[s0, s1]" = torch.ops.aten.clone.default(primals_4); primals_4 = None - mul_6: "Sym(s16*s47)" = primals_1 * primals_2; primals_1 = primals_2 = None - view: "f32[s16*s47]" = torch.ops.aten.view.default(clone, [mul_6]); clone = None - view_1: "f32[s16*s47]" = torch.ops.aten.view.default(clone_1, [mul_6]); clone_1 = None + mul_6: "Sym(s0*s1)" = primals_1 * primals_2; primals_1 = primals_2 = None + view: "f32[s0*s1]" = torch.ops.aten.view.default(clone, [mul_6]); clone = None + view_1: "f32[s0*s1]" = torch.ops.aten.view.default(clone_1, [mul_6]); clone_1 = None return (view, view_1, mul_6, primals_5, primals_7) """, # noqa: B950 ) @@ -2143,9 +2143,9 @@ class GraphModule(torch.nn.Module): normalize_gm(bw[0].print_readable(print_output=False)), """\ class GraphModule(torch.nn.Module): - def forward(self, primals_5: "Sym(s47)", primals_7: "Sym(s16)", tangents_1: "f32[s16*s47]", tangents_2: "f32[s16*s47]"): - view_2: "f32[s47, s16]" = torch.ops.aten.view.default(tangents_1, [primals_5, primals_7]); tangents_1 = None - view_3: "f32[s47, s16]" = torch.ops.aten.view.default(tangents_2, [primals_5, primals_7]); tangents_2 = None + def forward(self, primals_5: "Sym(s0)", primals_7: "Sym(s1)", tangents_1: "f32[s0*s1]", tangents_2: "f32[s0*s1]"): + view_2: "f32[s0, s1]" = torch.ops.aten.view.default(tangents_1, [primals_5, primals_7]); tangents_1 = None + view_3: "f32[s0, s1]" = torch.ops.aten.view.default(tangents_2, [primals_5, primals_7]); tangents_2 = None return (None, None, view_2, view_3, primals_5, primals_7, primals_7) """, # noqa: B950 ) @@ -2165,13 +2165,13 @@ class GraphModule(torch.nn.Module): normalize_gm(fw[0].print_readable(print_output=False)), """\ class GraphModule(torch.nn.Module): - def forward(self, primals_1: "Sym(s47)", primals_2: "Sym(s16)", primals_3: "f32[s47, s16]", primals_4: "f32[s47, s16]", primals_5: "Sym(s47)", primals_6: "Sym(s16)", primals_7: "Sym(s16)"): - clone: "f32[s47, s16]" = torch.ops.aten.clone.default(primals_3); primals_3 = None - clone_1: "f32[s47, s16]" = torch.ops.aten.clone.default(primals_4); primals_4 = None + def forward(self, primals_1: "Sym(s0)", primals_2: "Sym(s1)", primals_3: "f32[s0, s1]", primals_4: "f32[s0, s1]", primals_5: "Sym(s0)", primals_6: "Sym(s1)", primals_7: "Sym(s1)"): + clone: "f32[s0, s1]" = torch.ops.aten.clone.default(primals_3); primals_3 = None + clone_1: "f32[s0, s1]" = torch.ops.aten.clone.default(primals_4); primals_4 = None - mul_6: "Sym(s16*s47)" = primals_1 * primals_2; primals_1 = primals_2 = None - view: "f32[s16*s47]" = torch.ops.aten.view.default(clone, [mul_6]) - view_1: "f32[s16*s47]" = torch.ops.aten.view.default(clone_1, [mul_6]); clone_1 = None + mul_6: "Sym(s0*s1)" = primals_1 * primals_2; primals_1 = primals_2 = None + view: "f32[s0*s1]" = torch.ops.aten.view.default(clone, [mul_6]) + view_1: "f32[s0*s1]" = torch.ops.aten.view.default(clone_1, [mul_6]); clone_1 = None return (clone, view, view_1, mul_6, primals_5, primals_7) """, # noqa: B950 ) @@ -2180,9 +2180,9 @@ class GraphModule(torch.nn.Module): normalize_gm(bw[0].print_readable(print_output=False)), """\ class GraphModule(torch.nn.Module): - def forward(self, primals_5: "Sym(s47)", primals_7: "Sym(s16)", tangents_1: "f32[s16*s47]", tangents_2: "f32[s16*s47]"): - view_2: "f32[s47, s16]" = torch.ops.aten.view.default(tangents_1, [primals_5, primals_7]); tangents_1 = None - view_3: "f32[s47, s16]" = torch.ops.aten.view.default(tangents_2, [primals_5, primals_7]); tangents_2 = None + def forward(self, primals_5: "Sym(s0)", primals_7: "Sym(s1)", tangents_1: "f32[s0*s1]", tangents_2: "f32[s0*s1]"): + view_2: "f32[s0, s1]" = torch.ops.aten.view.default(tangents_1, [primals_5, primals_7]); tangents_1 = None + view_3: "f32[s0, s1]" = torch.ops.aten.view.default(tangents_2, [primals_5, primals_7]); tangents_2 = None return (None, None, view_2, view_3, primals_5, primals_7, primals_7) """, # noqa: B950 ) @@ -2261,13 +2261,13 @@ class GraphModule(torch.nn.Module): normalize_gm(fw[1].print_readable(print_output=False)), """\ class GraphModule(torch.nn.Module): - def forward(self, primals_1: "Sym(s16)", primals_2: "f32[3, s16]", primals_3: "f32[3, s16]", primals_4: "Sym(s16)", primals_5: "Sym(s16)"): - clone: "f32[3, s16]" = torch.ops.aten.clone.default(primals_2); primals_2 = None - clone_1: "f32[3, s16]" = torch.ops.aten.clone.default(primals_3); primals_3 = None + def forward(self, primals_1: "Sym(s0)", primals_2: "f32[3, s0]", primals_3: "f32[3, s0]", primals_4: "Sym(s0)", primals_5: "Sym(s0)"): + clone: "f32[3, s0]" = torch.ops.aten.clone.default(primals_2); primals_2 = None + clone_1: "f32[3, s0]" = torch.ops.aten.clone.default(primals_3); primals_3 = None - view: "f32[3*s16]" = torch.ops.aten.view.default(clone, [-1]) - sym_numel_default: "Sym(3*s16)" = torch.ops.aten.sym_numel.default(clone) - view_1: "f32[3*s16]" = torch.ops.aten.view.default(clone_1, [-1]) + view: "f32[3*s0]" = torch.ops.aten.view.default(clone, [-1]) + sym_numel_default: "Sym(3*s0)" = torch.ops.aten.sym_numel.default(clone) + view_1: "f32[3*s0]" = torch.ops.aten.view.default(clone_1, [-1]) return (clone, view, view_1, sym_numel_default, clone_1, primals_5) """, # noqa: B950 ) @@ -2287,9 +2287,9 @@ class GraphModule(torch.nn.Module): normalize_gm(bw[1].print_readable(print_output=False)), """\ class GraphModule(torch.nn.Module): - def forward(self, primals_5: "Sym(s16)", tangents_1: "f32[3*s16]", tangents_2: "f32[3*s16]"): - view_2: "f32[3, s16]" = torch.ops.aten.view.default(tangents_1, [3, primals_5]); tangents_1 = None - view_3: "f32[3, s16]" = torch.ops.aten.view.default(tangents_2, [3, primals_5]); tangents_2 = None + def forward(self, primals_5: "Sym(s0)", tangents_1: "f32[3*s0]", tangents_2: "f32[3*s0]"): + view_2: "f32[3, s0]" = torch.ops.aten.view.default(tangents_1, [3, primals_5]); tangents_1 = None + view_3: "f32[3, s0]" = torch.ops.aten.view.default(tangents_2, [3, primals_5]); tangents_2 = None return (None, view_2, view_3, primals_5, primals_5) """, # noqa: B950 ) @@ -2317,13 +2317,13 @@ class GraphModule(torch.nn.Module): normalize_gm(fw[0].print_readable(print_output=False)), """\ class GraphModule(torch.nn.Module): - def forward(self, primals_1: "Sym(s16)", primals_2: "f32[3, s16]", primals_3: "f32[3, s16]", primals_4: "Sym(s16)", primals_5: "Sym(s16)"): - clone: "f32[3, s16]" = torch.ops.aten.clone.default(primals_2); primals_2 = None - clone_1: "f32[3, s16]" = torch.ops.aten.clone.default(primals_3); primals_3 = None + def forward(self, primals_1: "Sym(s0)", primals_2: "f32[3, s0]", primals_3: "f32[3, s0]", primals_4: "Sym(s0)", primals_5: "Sym(s0)"): + clone: "f32[3, s0]" = torch.ops.aten.clone.default(primals_2); primals_2 = None + clone_1: "f32[3, s0]" = torch.ops.aten.clone.default(primals_3); primals_3 = None - view: "f32[3*s16]" = torch.ops.aten.view.default(clone, [-1]) - sym_numel_default: "Sym(3*s16)" = torch.ops.aten.sym_numel.default(clone) - view_1: "f32[3*s16]" = torch.ops.aten.view.default(clone_1, [-1]) + view: "f32[3*s0]" = torch.ops.aten.view.default(clone, [-1]) + sym_numel_default: "Sym(3*s0)" = torch.ops.aten.sym_numel.default(clone) + view_1: "f32[3*s0]" = torch.ops.aten.view.default(clone_1, [-1]) return (clone, view, view_1, sym_numel_default, clone_1, primals_5) """, # noqa: B950 ) @@ -2332,9 +2332,9 @@ class GraphModule(torch.nn.Module): normalize_gm(bw[0].print_readable(print_output=False)), """\ class GraphModule(torch.nn.Module): - def forward(self, primals_5: "Sym(s16)", tangents_1: "f32[3*s16]", tangents_2: "f32[3*s16]"): - view_2: "f32[3, s16]" = torch.ops.aten.view.default(tangents_1, [3, primals_5]); tangents_1 = None - view_3: "f32[3, s16]" = torch.ops.aten.view.default(tangents_2, [3, primals_5]); tangents_2 = None + def forward(self, primals_5: "Sym(s0)", tangents_1: "f32[3*s0]", tangents_2: "f32[3*s0]"): + view_2: "f32[3, s0]" = torch.ops.aten.view.default(tangents_1, [3, primals_5]); tangents_1 = None + view_3: "f32[3, s0]" = torch.ops.aten.view.default(tangents_2, [3, primals_5]); tangents_2 = None return (None, view_2, view_3, primals_5, primals_5) """, # noqa: B950 ) @@ -2501,10 +2501,10 @@ class GraphModule(torch.nn.Module): normalize_gm(fw[0].print_readable(print_output=False)), """\ class GraphModule(torch.nn.Module): - def forward(self, primals_1: "Sym(s51)", primals_2: "Sym(s71)", primals_3: "Sym(s55)", primals_4: "f64[s64, s55]", primals_5: "i64[s51 + 1]", primals_6: "f32[s0, 0]", primals_7: "f32[s83, 0]", primals_8: "Sym(s51)", primals_9: "Sym(s55)", primals_10: "Sym(s55)"): - clone: "f64[s64, s55]" = torch.ops.aten.clone.default(primals_4); primals_4 = None + def forward(self, primals_1: "Sym(s2)", primals_2: "Sym(s3)", primals_3: "Sym(s1)", primals_4: "f64[s0, s1]", primals_5: "i64[s2 + 1]", primals_6: "f32[s6, 0]", primals_7: "f32[s7, 0]", primals_8: "Sym(s2)", primals_9: "Sym(s1)", primals_10: "Sym(s1)"): + clone: "f64[s0, s1]" = torch.ops.aten.clone.default(primals_4); primals_4 = None - mul: "f64[s64, s55]" = torch.ops.aten.mul.Tensor(clone, primals_1); clone = None + mul: "f64[s0, s1]" = torch.ops.aten.mul.Tensor(clone, primals_1); clone = None return (mul, primals_5, primals_6, primals_7, primals_8, primals_10, primals_10, primals_1, primals_8, primals_10) """, # noqa: B950 ) @@ -2513,8 +2513,8 @@ class GraphModule(torch.nn.Module): normalize_gm(bw[0].print_readable(print_output=False)), """\ class GraphModule(torch.nn.Module): - def forward(self, primals_1: "Sym(s51)", primals_8: "Sym(s51)", primals_10: "Sym(s55)", tangents_1: "f64[s64, s55]", tangents_2: "i64[s51 + 1]", tangents_3: "f32[s0, 0]", tangents_4: "f32[s83, 0]"): - mul_1: "f64[s64, s55]" = torch.ops.aten.mul.Tensor(tangents_1, primals_1); tangents_1 = primals_1 = None + def forward(self, primals_1: "Sym(s2)", primals_8: "Sym(s2)", primals_10: "Sym(s1)", tangents_1: "f64[s0, s1]", tangents_2: "i64[s2 + 1]", tangents_3: "f32[s6, 0]", tangents_4: "f32[s7, 0]"): + mul_1: "f64[s0, s1]" = torch.ops.aten.mul.Tensor(tangents_1, primals_1); tangents_1 = primals_1 = None return (None, None, None, mul_1, tangents_2, tangents_3, tangents_4, primals_8, primals_10, primals_10) """, # noqa: B950 ) @@ -2534,11 +2534,11 @@ class GraphModule(torch.nn.Module): normalize_gm(fw[0].print_readable(print_output=False)), """\ class GraphModule(torch.nn.Module): - def forward(self, primals_1: "Sym(s51)", primals_2: "Sym(s71)", primals_3: "Sym(s55)", primals_4: "f64[s64, s55]", primals_5: "i64[s51 + 1]", primals_6: "f32[s0, 0]", primals_7: "f32[s83, 0]", primals_8: "Sym(s51)", primals_9: "Sym(s55)", primals_10: "Sym(s55)"): - clone: "f64[s64, s55]" = torch.ops.aten.clone.default(primals_4); primals_4 = None + def forward(self, primals_1: "Sym(s2)", primals_2: "Sym(s3)", primals_3: "Sym(s1)", primals_4: "f64[s0, s1]", primals_5: "i64[s2 + 1]", primals_6: "f32[s6, 0]", primals_7: "f32[s7, 0]", primals_8: "Sym(s2)", primals_9: "Sym(s1)", primals_10: "Sym(s1)"): + clone: "f64[s0, s1]" = torch.ops.aten.clone.default(primals_4); primals_4 = None - cat: "f64[s64, 2*s55]" = torch.ops.aten.cat.default([clone, clone], 1); clone = None - add_2: "Sym(2*s55)" = primals_10 + primals_10 + cat: "f64[s0, 2*s1]" = torch.ops.aten.cat.default([clone, clone], 1); clone = None + add_2: "Sym(2*s1)" = primals_10 + primals_10 return (cat, primals_5, primals_6, primals_7, primals_8, add_2, add_2, primals_8, primals_10, add_2) """, # noqa: B950 ) @@ -2547,11 +2547,11 @@ class GraphModule(torch.nn.Module): normalize_gm(bw[0].print_readable(print_output=False)), """\ class GraphModule(torch.nn.Module): - def forward(self, primals_8: "Sym(s51)", primals_10: "Sym(s55)", add_2: "Sym(2*s55)", tangents_1: "f64[s64, 2*s55]", tangents_2: "i64[s51 + 1]", tangents_3: "f32[s0, 0]", tangents_4: "f32[s83, 0]"): - slice_1: "f64[s64, s55]" = torch.ops.aten.slice.Tensor(tangents_1, 1, 0, primals_10) - slice_2: "f64[s64, s55]" = torch.ops.aten.slice.Tensor(tangents_1, 1, primals_10, add_2); tangents_1 = add_2 = None + def forward(self, primals_8: "Sym(s2)", primals_10: "Sym(s1)", add_2: "Sym(2*s1)", tangents_1: "f64[s0, 2*s1]", tangents_2: "i64[s2 + 1]", tangents_3: "f32[s6, 0]", tangents_4: "f32[s7, 0]"): + slice_1: "f64[s0, s1]" = torch.ops.aten.slice.Tensor(tangents_1, 1, 0, primals_10) + slice_2: "f64[s0, s1]" = torch.ops.aten.slice.Tensor(tangents_1, 1, primals_10, add_2); tangents_1 = add_2 = None - add_4: "f64[s64, s55]" = torch.ops.aten.add.Tensor(slice_1, slice_2); slice_1 = slice_2 = None + add_4: "f64[s0, s1]" = torch.ops.aten.add.Tensor(slice_1, slice_2); slice_1 = slice_2 = None return (None, None, None, add_4, tangents_2, tangents_3, tangents_4, primals_8, primals_10, primals_10) """, # noqa: B950 ) @@ -2580,7 +2580,7 @@ class GraphModule(torch.nn.Module): normalize_gm(fw[0].print_readable(print_output=False)), """\ class (torch.nn.Module): - def forward(self, arg0_1: "Sym(s51)", arg1_1: "Sym(s71)", arg2_1: "Sym(s55)", arg3_1: "f64[9, s55]", arg4_1: "i64[s51 + 1]", arg5_1: "f32[s0, 0]", arg6_1: "f32[s83, 0]", arg7_1: "Sym(s51)", arg8_1: "Sym(s55)", arg9_1: "Sym(s55)"): + def forward(self, arg0_1: "Sym(s3)", arg1_1: "Sym(s4)", arg2_1: "Sym(s2)", arg3_1: "f64[9, s2]", arg4_1: "i64[s3 + 1]", arg5_1: "f32[s7, 0]", arg6_1: "f32[s8, 0]", arg7_1: "Sym(s3)", arg8_1: "Sym(s2)", arg9_1: "Sym(s2)"): randn: "f64[2, 5]" = torch.ops.aten.randn.default([2, 5], dtype = torch.float64, device = device(type='cpu'), pin_memory = False) randn_1: "f64[3, 5]" = torch.ops.aten.randn.default([3, 5], dtype = torch.float64, device = device(type='cpu'), pin_memory = False) randn_2: "f64[4, 5]" = torch.ops.aten.randn.default([4, 5], dtype = torch.float64, device = device(type='cpu'), pin_memory = False) @@ -2594,13 +2594,13 @@ class (torch.nn.Module): zeros_1: "f32[2, 0]" = torch.ops.aten.zeros.default([2, 0], device = device(type='cpu'), pin_memory = False) zeros_2: "f32[4, 0]" = torch.ops.aten.zeros.default([4, 0], device = device(type='cpu'), pin_memory = False) - cat_2: "f64[9, s55 + 5]" = torch.ops.aten.cat.default([cat, arg3_1], 1); cat = arg3_1 = None + cat_2: "f64[9, s2 + 5]" = torch.ops.aten.cat.default([cat, arg3_1], 1); cat = arg3_1 = None - sin: "f64[9, s55 + 5]" = torch.ops.aten.sin.default(cat_2) - mul: "f64[9, s55 + 5]" = torch.ops.aten.mul.Tensor(sin, 3); sin = None + sin: "f64[9, s2 + 5]" = torch.ops.aten.sin.default(cat_2) + mul: "f64[9, s2 + 5]" = torch.ops.aten.mul.Tensor(sin, 3); sin = None - sym_size_int: "Sym(s55 + 5)" = torch.ops.aten.sym_size.int(cat_2, 1); cat_2 = None - sym_stride_int: "Sym(s55 + 5)" = torch.ops.aten.sym_stride.int(mul, 0) + sym_size_int: "Sym(s2 + 5)" = torch.ops.aten.sym_size.int(cat_2, 1); cat_2 = None + sym_stride_int: "Sym(s2 + 5)" = torch.ops.aten.sym_stride.int(mul, 0) return (mul, cat_1, zeros_1, zeros_2, sym_size_int, sym_stride_int) """, # noqa: B950 ) @@ -2757,10 +2757,10 @@ class TestNestedTensor(torch._dynamo.test_case.TestCase, NestedTensorTestCase): norm_graph, """\ class GraphModule(torch.nn.Module): - def forward(self, s71: "Sym(s71)", L_nt_: "f64[3, s71, 5]"): + def forward(self, s1: "Sym(s1)", L_nt_: "f64[3, s1, 5]"): l_nt_ = L_nt_ - add: "f64[3, s71, 5]" = l_nt_ + 2; l_nt_ = None + add: "f64[3, s1, 5]" = l_nt_ + 2; l_nt_ = None return (add,) """, # noqa: B950 ) @@ -3254,27 +3254,27 @@ class GraphModule(torch.nn.Module): # varies based on the type of view guard_str = "\n".join(guards) if nt_view_name == "subclass_dense": - self.assertExpectedInline(guard_str, """Eq(s85 - 1, s77)""") + self.assertExpectedInline(guard_str, """Eq(s3 - 1, s0)""") elif nt_view_name == "dense_subclass_dense_subclass": self.assertExpectedInline( guard_str, """\ -Eq(s85 - 1, s77) -Eq(s80 - 1, s78) -Eq(s72, s71)""", +Eq(s5 - 1, s2) +Eq(s12 - 1, s7) +Eq(s11, s9)""", ) elif nt_view_name.startswith("base_is_nt_True"): self.assertExpectedInline( guard_str, - """Eq(s17 - 1, s83)""", + """Eq(s3 - 1, s0)""", ) else: self.assertExpectedInline( guard_str, """\ -Eq(s85 - 1, s64) -Eq(s80 - 1, s77) -Eq(s72, s71)""", +Eq(s4 - 1, s1) +Eq(s13 - 1, s8) +Eq(s12, s10)""", ) return gm diff --git a/test/export/test_export.py b/test/export/test_export.py index 87229d60b9c..3b1efd24a37 100755 --- a/test/export/test_export.py +++ b/test/export/test_export.py @@ -1560,7 +1560,7 @@ graph(): ) with self.assertRaisesRegex( error_type, - r"Real tensor propagation found an output size mismatch between fake shape s\d+ and real shape 4, " + r"Real tensor propagation found an output size mismatch between fake shape s1 and real shape 4, " r"at output\.size\(0\), for func: mylib.foo.default", ): export( @@ -2848,7 +2848,7 @@ def forward(self, p_linear_weight, p_linear_bias, x): with self.assertRaisesRegex( RuntimeError, "Expected input.*shape.*= 9 to be " - "of the form 2\\*s92, where s92 is an integer", + "of the form 2\\*s1, where s1 is an integer", ): ep.module()(torch.randn(9)) @@ -3506,11 +3506,8 @@ def forward(self, x): dynamic_shapes=({0: Dim("x")},), ) - # Since symbol names are based on hash of source names, and these differ across inference and - # training, we do range comparisons instead. self.assertEqual( - str(ep_for_training.range_constraints.values()), - str(ep_for_real.range_constraints.values()), + str(ep_for_training.range_constraints), str(ep_for_real.range_constraints) ) def test_export_for_training_with_container_type(self): @@ -4388,7 +4385,7 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x): em.module()(torch.randn(4, 3)) with self.assertRaisesRegex( RuntimeError, - r"Runtime assertion failed for expression Eq\(Mod\(s27\*s77, s77 \- 1\), 0\)", + r"Runtime assertion failed for expression Eq\(Mod\(s0\*s1, s0 \- 1\), 0\)", ): em.module()(torch.randn(4, 5)) @@ -4399,7 +4396,7 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x): x = torch.randn(3, 5) with self.assertRaisesRegex( RuntimeError, - "Expected.*shape\\[1\\] = 5 to be of the form 2\\*s33, where s33 is an integer", + "Expected.*shape\\[1\\] = 5 to be of the form 2\\*s1, where s1 is an integer", ): em.module()(x) @@ -4958,14 +4955,11 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x): ) self.assertEqual( [ - # First dimension varies across strict and non-strict - # since the source names are different, resulting in - # different symbol names. - str(node.meta["val"].shape[1:]) + str(node.meta["val"].shape) for node in efoo.graph_module.graph.nodes if node.op == "placeholder" ], - ["torch.Size([2, 3])", "torch.Size([3, 4])"], + ["torch.Size([s0, 2, 3])", "torch.Size([s0, 3, 4])"], ) @testing.expectedFailureCppSerDes @@ -5103,10 +5097,14 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x): "y": (batch, size, size), }, ) - - for node in efoo.graph_module.graph.nodes: - if node.op == "placeholder": - self.assertEqual(node.meta["val"].shape[1], node.meta["val"].shape[2]) + self.assertEqual( + [ + str(node.meta["val"].shape) + for node in efoo.graph_module.graph.nodes + if node.op == "placeholder" + ], + ["torch.Size([s0, s1, s1])", "torch.Size([s0, s1, s1])"], + ) self.assertEqual(efoo.module()(*inputs).shape, foo(*inputs).shape) # pass dynamic shapes of inputs [multiple, mostly distinct] @@ -5117,14 +5115,13 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x): inputs, dynamic_shapes={"x": (batch, M, K), "y": (batch, K, N)}, ) - placeholders = [ - node.meta["val"].shape - for node in efoo.graph_module.graph.nodes - if node.op == "placeholder" - ] self.assertEqual( - placeholders[0][2], - placeholders[1][1], + [ + str(node.meta["val"].shape) + for node in efoo.graph_module.graph.nodes + if node.op == "placeholder" + ], + ["torch.Size([s0, s1, s2])", "torch.Size([s0, s2, s5])"], ) self.assertEqual(efoo.module()(*inputs).shape, foo(*inputs).shape) @@ -5141,14 +5138,11 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x): ) self.assertEqual( [ - # First dimension varies across strict and non-strict - # since the source names are different, resulting in - # different symbol names. - str(node.meta["val"].shape[1:]) + str(node.meta["val"].shape) for node in efoo.graph_module.graph.nodes if node.op == "placeholder" ], - ["torch.Size([2, 3])", "torch.Size([3, 4])"], + ["torch.Size([s0, 2, 3])", "torch.Size([s0, 3, 4])"], ) self.assertEqual(efoo.module()(*inputs).shape, foo(*inputs).shape) @@ -5165,14 +5159,11 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x): ) self.assertEqual( [ - # First dimension varies across strict and non-strict - # since the source names are different, resulting in - # different symbol names. - str(node.meta["val"].shape[1:]) + str(node.meta["val"].shape) for node in efoo.graph_module.graph.nodes if node.op == "placeholder" ], - ["torch.Size([2, 3])", "torch.Size([3, 4])"], + ["torch.Size([s0, 2, 3])", "torch.Size([s0, 3, 4])"], ) self.assertEqual(efoo.module()(*inputs).shape, foo(*inputs).shape) @@ -5482,7 +5473,7 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x): if node.op == "placeholder" ] self.assertEqual(len(input_shapes), 9) - self.assertTrue(all(shape == "torch.Size([s3])" for shape in input_shapes)) + self.assertTrue(all(shape == "torch.Size([s0])" for shape in input_shapes)) def test_error_does_not_reference_eager_fallback(self): class Module(torch.nn.Module): @@ -11146,7 +11137,7 @@ def forward(self, x, y): self.assertEqual(out2.shape, torch.ones(11, 4, 3).shape) with self.assertRaisesRegex( RuntimeError, - r"Runtime assertion failed for expression Eq\(Mod\(s27\*s77, 4\*s77 \- 4\), 0\) on node 'eq.*'", + r"Runtime assertion failed for expression Eq\(Mod\(s0\*s1, 4\*s0 \- 4\), 0\) on node 'eq.*'", ): ep.module()(torch.randn(8, 8)) # fail @@ -11178,7 +11169,7 @@ def forward(self, x, y): self.assertEqual(out2.shape, torch.ones(40).shape) with self.assertRaisesRegex( RuntimeError, - r"Runtime assertion failed for expression Eq\((.*)\) on node '.*'", + r"Runtime assertion failed for expression Eq\(s0\*s1, s2\*s3\) on node 'eq.*'", ): # fail only at runtime ep.module()(torch.randn(5, 8), torch.randn(4, 5), torch.randn(30)) # fail @@ -11205,7 +11196,7 @@ def forward(self, x, y): self.assertEqual(out1.shape, torch.ones(126).shape) with self.assertRaisesRegex( RuntimeError, - r"Runtime assertion failed for expression Eq\((.*)\) on node '.*'", + r"Runtime assertion failed for expression Eq\(s0\*s1\*s2, s3\) on node 'eq.*'", ): # fail only at runtime ep.module()(torch.randn(4, 3, 2), torch.randn(10)) # fail @@ -11286,12 +11277,12 @@ def forward(self, x, y): ) with self.assertRaisesRegex( RuntimeError, - r"Runtime assertion failed for expression Ne\(s77, 20\)", + r"Runtime assertion failed for expression Ne\(s0, 20\)", ): ep.module()(torch.randn(20, 20, 16)) with self.assertRaisesRegex( RuntimeError, - r"Runtime assertion failed for expression Ne\(Mod\(s77, 20\), 0\)", + r"Runtime assertion failed for expression Ne\(Mod\(s0, 20\), 0\)", ): ep.module()(torch.randn(400, 20, 16)) ep.module()(torch.randn(42, 20, 16)) @@ -11329,17 +11320,17 @@ def forward(self, x, y): self.assertEqual(out1.shape, torch.ones(27).shape) with self.assertRaisesRegex( RuntimeError, - r"Runtime assertion failed for expression Ne\(s77, s17\)", + r"Runtime assertion failed for expression Ne\(s0, s1\)", ): # fail only at runtime ep.module()(torch.randn(4), torch.randn(4)) # fail with self.assertRaisesRegex( RuntimeError, - r"Runtime assertion failed for expression Ne\(s77, s17\**3\)", + r"Runtime assertion failed for expression Ne\(s0, s1\**3\)", ): ep.module()(torch.randn(64), torch.randn(4)) # fail with self.assertRaisesRegex( RuntimeError, - r"Runtime assertion failed for expression Eq\(s77\**2, 3\*s17\)", + r"Runtime assertion failed for expression Eq\(s0\**2, 3\*s1\)", ): ep.module()(torch.randn(10), torch.randn(9)) # fail diff --git a/test/export/test_serialize.py b/test/export/test_serialize.py index f5a324c7afd..81fe2867620 100644 --- a/test/export/test_serialize.py +++ b/test/export/test_serialize.py @@ -539,12 +539,8 @@ def forward(self, x): ep.range_constraints[symint] = ValueRanges(lower=lower_range, upper=upper_range) serialized = ExportedProgramSerializer().serialize(ep) - self.assertEqual( - serialized.exported_program.range_constraints["s77"].min_val, 2 - ) - self.assertEqual( - serialized.exported_program.range_constraints["s77"].max_val, 3 - ) + self.assertEqual(serialized.exported_program.range_constraints["s0"].min_val, 2) + self.assertEqual(serialized.exported_program.range_constraints["s0"].max_val, 3) def test_kwargs_default(self) -> None: """ diff --git a/test/functorch/test_aotdispatch.py b/test/functorch/test_aotdispatch.py index 14804de00c2..0343f1ed772 100644 --- a/test/functorch/test_aotdispatch.py +++ b/test/functorch/test_aotdispatch.py @@ -5941,8 +5941,8 @@ class TestAOTModuleSimplified(AOTTestCase): self.assertExpectedInline( shape_env.format_guards(), """\ - - Eq(s49, 20) - - Eq(s70, 30)""", + - Eq(s1, 20) + - Eq(s2, 30)""", ) assert torch.allclose(ref[0], res[0]) diff --git a/test/functorch/test_control_flow.py b/test/functorch/test_control_flow.py index 0e3f39eb226..a80a77f98a6 100644 --- a/test/functorch/test_control_flow.py +++ b/test/functorch/test_control_flow.py @@ -4553,10 +4553,10 @@ def forward(self, l_iter_, l_x_, l__self___dec_cond_fn, l__self___linear_bias_bo gm.code.strip("\n"), """\ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1): - sym_size_int = torch.ops.aten.sym_size.int(arg3_1, 1) + sym_size_int = torch.ops.aten.sym_size.int(arg2_1, 0) sym_size_int_1 = torch.ops.aten.sym_size.int(arg2_1, 1) - sym_size_int_2 = torch.ops.aten.sym_size.int(arg2_1, 0) - sym_size_int_3 = torch.ops.aten.sym_size.int(arg3_1, 0) + sym_size_int_2 = torch.ops.aten.sym_size.int(arg3_1, 0) + sym_size_int_3 = torch.ops.aten.sym_size.int(arg3_1, 1) while_loop_cond_graph_0 = self.while_loop_cond_graph_0 while_loop_body_graph_0 = self.while_loop_body_graph_0 while_loop = torch.ops.higher_order.while_loop(while_loop_cond_graph_0, while_loop_body_graph_0, (arg0_1, arg1_1, arg2_1, arg3_1), (sym_size_int, sym_size_int_1, sym_size_int_2, sym_size_int_3)); while_loop_cond_graph_0 = while_loop_body_graph_0 = arg0_1 = arg1_1 = arg2_1 = arg3_1 = sym_size_int = sym_size_int_1 = sym_size_int_2 = sym_size_int_3 = None @@ -4719,10 +4719,10 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1 def forward(self, a_1, b_1): sum_1 = torch.ops.aten.sum.default(a_1) gt = torch.ops.aten.gt.Scalar(sum_1, 0); sum_1 = None - sym_size_int = torch.ops.aten.sym_size.int(a_1, 1) - sym_size_int_1 = torch.ops.aten.sym_size.int(b_1, 0) - sym_size_int_2 = torch.ops.aten.sym_size.int(b_1, 1) - sym_size_int_3 = torch.ops.aten.sym_size.int(a_1, 0) + sym_size_int = torch.ops.aten.sym_size.int(a_1, 0) + sym_size_int_1 = torch.ops.aten.sym_size.int(a_1, 1) + sym_size_int_2 = torch.ops.aten.sym_size.int(b_1, 0) + sym_size_int_3 = torch.ops.aten.sym_size.int(b_1, 1) true_graph_0 = self.true_graph_0 false_graph_0 = self.false_graph_0 cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, (a_1, b_1, sym_size_int, sym_size_int_1, sym_size_int_2, sym_size_int_3)); gt = true_graph_0 = false_graph_0 = a_1 = b_1 = sym_size_int = sym_size_int_1 = sym_size_int_2 = sym_size_int_3 = None @@ -5856,7 +5856,7 @@ def forward(self, x_1): sym_size_int_1 = torch.ops.aten.sym_size.int(x_1, 1) true_graph_0 = self.true_graph_0 false_graph_0 = self.false_graph_0 - cond = torch.ops.higher_order.cond(eq, true_graph_0, false_graph_0, (x_1, sym_size_int_1, sym_size_int)); eq = true_graph_0 = false_graph_0 = x_1 = sym_size_int_1 = sym_size_int = None + cond = torch.ops.higher_order.cond(eq, true_graph_0, false_graph_0, (x_1, sym_size_int, sym_size_int_1)); eq = true_graph_0 = false_graph_0 = x_1 = sym_size_int = sym_size_int_1 = None getitem = cond[0]; cond = None return getitem""", # noqa: B950 ) @@ -5969,7 +5969,7 @@ def forward(self, x_1): false_graph_0 = self.false_graph_0 _tensor_constant0 = self._tensor_constant0 _tensor_constant1 = self._tensor_constant1 - cond = torch.ops.higher_order.cond(eq, true_graph_0, false_graph_0, (x_1, _tensor_constant0, sym_size_int_1, sym_size_int, _tensor_constant1)); eq = true_graph_0 = false_graph_0 = x_1 = _tensor_constant0 = sym_size_int_1 = sym_size_int = _tensor_constant1 = None + cond = torch.ops.higher_order.cond(eq, true_graph_0, false_graph_0, (x_1, _tensor_constant0, sym_size_int, sym_size_int_1, _tensor_constant1)); eq = true_graph_0 = false_graph_0 = x_1 = _tensor_constant0 = sym_size_int = sym_size_int_1 = _tensor_constant1 = None getitem = cond[0]; cond = None return getitem""", # noqa: B950 ) @@ -6209,7 +6209,7 @@ def forward(self, x_1): sym_size_int_1 = torch.ops.aten.sym_size.int(x_1, 1) true_graph_0 = self.true_graph_0 false_graph_0 = self.false_graph_0 - cond = torch.ops.higher_order.cond(eq, true_graph_0, false_graph_0, (x_1, sym_size_int_1, sym_size_int)); eq = true_graph_0 = false_graph_0 = x_1 = sym_size_int_1 = sym_size_int = None + cond = torch.ops.higher_order.cond(eq, true_graph_0, false_graph_0, (x_1, sym_size_int, sym_size_int_1)); eq = true_graph_0 = false_graph_0 = x_1 = sym_size_int = sym_size_int_1 = None getitem = cond[0]; cond = None return getitem""", # noqa: B950 ) @@ -6558,14 +6558,14 @@ def forward(self, l_inp_, l_tmp_): self.assertExpectedInline( backend.graphs[0].code.strip(), """\ -def forward(self, s97 : torch.SymInt, L_a_ : torch.Tensor, L_b_ : torch.Tensor, L_self_num : torch.SymInt): +def forward(self, s0 : torch.SymInt, L_a_ : torch.Tensor, L_b_ : torch.Tensor, L_self_num : torch.SymInt): l_a_ = L_a_ l_b_ = L_b_ l_self_num = L_self_num tensor = torch.tensor([True]) cond_true_0 = self.cond_true_0 cond_false_0 = self.cond_false_0 - cond = torch.ops.higher_order.cond(tensor, cond_true_0, cond_false_0, (l_a_, l_b_, l_self_num, s97)); tensor = cond_true_0 = cond_false_0 = l_a_ = l_b_ = l_self_num = s97 = None + cond = torch.ops.higher_order.cond(tensor, cond_true_0, cond_false_0, (l_a_, l_b_, l_self_num, s0)); tensor = cond_true_0 = cond_false_0 = l_a_ = l_b_ = l_self_num = s0 = None getitem = cond[0]; cond = None return (getitem,)""", # noqa: B950 ) @@ -6753,10 +6753,10 @@ def forward(self, L_init_ : torch.Tensor, L_xs_ : torch.Tensor, L_add_closure_0_ """\ class GraphModule(torch.nn.Module): def forward(self, x): - x: "f32[s35, 3]"; + x: "f32[s0, 3]"; x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec) - sym_size_int_1: "Sym(s35)" = torch.ops.aten.sym_size.int(x, 0) + sym_size_int_1: "Sym(s0)" = torch.ops.aten.sym_size.int(x, 0) while_loop_cond_graph_0 = self.while_loop_cond_graph_0 while_loop_body_graph_0 = self.while_loop_body_graph_0 @@ -6770,27 +6770,27 @@ class GraphModule(torch.nn.Module): gt_1: "Sym(u1 > 0)" = getitem_2 > 0 _assert_scalar_default_1 = torch.ops.aten._assert_scalar.default(gt_1, "Runtime assertion failed for expression 0 < u1 on node 'gt_1'"); gt_1 = _assert_scalar_default_1 = None - getitem_1: "f32[s35, 3]" = while_loop[1]; while_loop = None + getitem_1: "f32[s0, 3]" = while_loop[1]; while_loop = None add: "Sym(u1 + 1)" = getitem_2 + 1 - add_1: "f32[s35, 3]" = torch.ops.aten.add.Tensor(getitem_1, getitem_2); getitem_1 = None + add_1: "f32[s0, 3]" = torch.ops.aten.add.Tensor(getitem_1, getitem_2); getitem_1 = None - lt: "Sym(u1 < s35)" = getitem_2 < sym_size_int_1; sym_size_int_1 = None + lt: "Sym(u1 < s0)" = getitem_2 < sym_size_int_1; sym_size_int_1 = None mul: "Sym(2*u1)" = getitem_2 * 2; getitem_2 = None ones: "f32[2*u1]" = torch.ops.aten.ones.default([mul], device = device(type='cpu'), pin_memory = False); mul = None return pytree.tree_unflatten((add, add_1, lt, ones), self._out_spec) class while_loop_cond_graph_0(torch.nn.Module): - def forward(self, it_1: "Sym(u0)", x_1: "f32[s35, 3]"): - sym_size_int: "Sym(s35)" = torch.ops.aten.sym_size.int(x_1, 0); x_1 = None - lt: "Sym(u0 < s35)" = it_1 < sym_size_int; it_1 = sym_size_int = None + def forward(self, it_1: "Sym(u0)", x_1: "f32[s0, 3]"): + sym_size_int: "Sym(s0)" = torch.ops.aten.sym_size.int(x_1, 0); x_1 = None + lt: "Sym(u0 < s0)" = it_1 < sym_size_int; it_1 = sym_size_int = None return lt class while_loop_body_graph_0(torch.nn.Module): - def forward(self, it_1: "Sym(u0)", x_1: "f32[s35, 3]"): - clone: "f32[s35, 3]" = torch.ops.aten.clone.default(x_1); x_1 = None + def forward(self, it_1: "Sym(u0)", x_1: "f32[s0, 3]"): + clone: "f32[s0, 3]" = torch.ops.aten.clone.default(x_1); x_1 = None select: "f32[3]" = torch.ops.aten.select.int(clone, 0, it_1) select_1: "f32[3]" = torch.ops.aten.select.int(clone, 0, it_1) add: "f32[3]" = torch.ops.aten.add.Tensor(select_1, it_1); select_1 = None @@ -6820,12 +6820,12 @@ class GraphModule(torch.nn.Module): normalize_gm(backend.graphs[0].print_readable(print_output=False)), """\ class GraphModule(torch.nn.Module): - def forward(self, s77: "Sym(s77)", s27: "Sym(s27)", L_x_: "f32[s77, s27]"): + def forward(self, s0: "Sym(s0)", s1: "Sym(s1)", L_x_: "f32[s0, s1]"): l_x_ = L_x_ cond_fn_0 = self.cond_fn_0 body_fn_0 = self.body_fn_0 - while_loop = torch.ops.higher_order.while_loop(cond_fn_0, body_fn_0, (0, l_x_), (s27, s77)); cond_fn_0 = body_fn_0 = l_x_ = s27 = None + while_loop = torch.ops.higher_order.while_loop(cond_fn_0, body_fn_0, (0, l_x_), (s0, s1)); cond_fn_0 = body_fn_0 = l_x_ = s1 = None getitem_4: "Sym(u1)" = while_loop[0] @@ -6835,49 +6835,49 @@ class GraphModule(torch.nn.Module): gt_1: "Sym(u1 > 0)" = getitem_4 > 0 _assert_scalar_default_1 = torch.ops.aten._assert_scalar.default(gt_1, "Runtime assertion failed for expression 0 < u1 on node 'gt_1'"); gt_1 = _assert_scalar_default_1 = None - out_x: "f32[s77, s27]" = while_loop[1]; while_loop = None + out_x: "f32[s0, s1]" = while_loop[1]; while_loop = None add: "Sym(u1 + 1)" = getitem_4 + 1 - add_1: "f32[s77, s27]" = getitem_4 + out_x; out_x = None + add_1: "f32[s0, s1]" = getitem_4 + out_x; out_x = None - lt: "Sym(u1 < s77)" = getitem_4 < s77; s77 = None + lt: "Sym(u1 < s0)" = getitem_4 < s0; s0 = None mul: "Sym(2*u1)" = getitem_4 * 2; getitem_4 = None ones: "f32[2*u1]" = torch.ones(mul); mul = None return (add, add_1, lt, ones) class cond_fn_0(torch.nn.Module): - def forward(self, unbacked_symint: "Sym(u0)", l_x_: "f32[s77, s27]", s27, s77): - s27_1 = s27 - s77_1 = s77 + def forward(self, unbacked_symint: "Sym(u0)", l_x_: "f32[s0, s1]", s0, s1): + s0_1 = s0 + s1_1 = s1 size = l_x_.size(); l_x_ = None - getitem: "Sym(s77)" = size[0] - getitem_1: "Sym(s27)" = size[1]; size = getitem_1 = None - lt: "Sym(u0 < s77)" = unbacked_symint < getitem; unbacked_symint = getitem = None + getitem: "Sym(s0)" = size[0] + getitem_1: "Sym(s1)" = size[1]; size = getitem_1 = None + lt: "Sym(u0 < s0)" = unbacked_symint < getitem; unbacked_symint = getitem = None return lt class body_fn_0(torch.nn.Module): - def forward(self, unbacked_symint: "Sym(u0)", l_x_: "f32[s77, s27]", s27, s77): - s27_1 = s27 - s77_1 = s77 + def forward(self, unbacked_symint: "Sym(u0)", l_x_: "f32[s0, s1]", s0, s1): + s0_1 = s0 + s1_1 = s1 - x_clone: "f32[s77, s27]" = l_x_.clone() + x_clone: "f32[s0, s1]" = l_x_.clone() ge: "Sym(u0 >= 0)" = unbacked_symint >= 0 _check = torch._check(ge); ge = _check = None size = l_x_.size(); l_x_ = None - getitem: "Sym(s77)" = size[0] - getitem_1: "Sym(s27)" = size[1]; size = getitem_1 = None - lt: "Sym(u0 < s77)" = unbacked_symint < getitem; getitem = None + getitem: "Sym(s0)" = size[0] + getitem_1: "Sym(s1)" = size[1]; size = getitem_1 = None + lt: "Sym(u0 < s0)" = unbacked_symint < getitem; getitem = None _check_1 = torch._check(lt); lt = _check_1 = None - select: "f32[s27]" = x_clone.select(0, unbacked_symint) - select_1: "f32[s27]" = x_clone.select(0, unbacked_symint) - add: "f32[s27]" = select_1 + unbacked_symint; select_1 = None - copy_: "f32[s27]" = select.copy_(add); select = add = copy_ = None + select: "f32[s1]" = x_clone.select(0, unbacked_symint) + select_1: "f32[s1]" = x_clone.select(0, unbacked_symint) + add: "f32[s1]" = select_1 + unbacked_symint; select_1 = None + copy_: "f32[s1]" = select.copy_(add); select = add = copy_ = None add_1: "Sym(u0 + 1)" = unbacked_symint + 1; unbacked_symint = None return (add_1, x_clone) @@ -7048,12 +7048,12 @@ class GraphModule(torch.nn.Module): """\ class GraphModule(torch.nn.Module): def forward(self, x): - x: "f32[s77, 3]"; + x: "f32[s0, 3]"; x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec) - sym_size_int_1: "Sym(s77)" = torch.ops.aten.sym_size.int(x, 0) + sym_size_int_1: "Sym(s0)" = torch.ops.aten.sym_size.int(x, 0) - sin: "f32[s77, 3]" = torch.ops.aten.sin.default(x); x = None + sin: "f32[s0, 3]" = torch.ops.aten.sin.default(x); x = None while_loop_cond_graph_0 = self.while_loop_cond_graph_0 while_loop_body_graph_0 = self.while_loop_body_graph_0 @@ -7065,19 +7065,19 @@ class GraphModule(torch.nn.Module): getitem_9: "Sym(u8)" = while_loop[3] getitem_10: "Sym(u9)" = while_loop[4] - getitem_5: "f32[s77, 3]" = while_loop[5]; while_loop = None + getitem_5: "f32[s0, 3]" = while_loop[5]; while_loop = None add: "Sym(u7 + 1)" = getitem_8 + 1 add_1: "Sym(u8 + 1)" = getitem_9 + 1 add_2: "Sym(u9 + 1)" = getitem_10 + 1 - add_3: "f32[s77, 3]" = torch.ops.aten.add.Tensor(getitem_5, getitem_8); getitem_8 = None - add_4: "f32[s77, 3]" = torch.ops.aten.add.Tensor(getitem_5, getitem_9); getitem_9 = None - add_5: "f32[s77, 3]" = torch.ops.aten.add.Tensor(getitem_5, getitem_10); getitem_10 = None + add_3: "f32[s0, 3]" = torch.ops.aten.add.Tensor(getitem_5, getitem_8); getitem_8 = None + add_4: "f32[s0, 3]" = torch.ops.aten.add.Tensor(getitem_5, getitem_9); getitem_9 = None + add_5: "f32[s0, 3]" = torch.ops.aten.add.Tensor(getitem_5, getitem_10); getitem_10 = None return pytree.tree_unflatten((getitem_6, getitem_7, add, add_1, add_2, add_3, add_4, add_5, getitem_5), self._out_spec) class while_loop_cond_graph_0(torch.nn.Module): - def forward(self, arg0_1: "Sym(u15)", arg1_1: "Sym(u16)", arg2_1: "Sym(u17)", arg3_1: "Sym(u18)", arg4_1: "Sym(u19)", arg5_1: "f32[s77, 3]"): + def forward(self, arg0_1: "Sym(u15)", arg1_1: "Sym(u16)", arg2_1: "Sym(u17)", arg3_1: "Sym(u18)", arg4_1: "Sym(u19)", arg5_1: "f32[s0, 3]"): mul: "Sym(u17*u18)" = arg2_1 * arg3_1; arg2_1 = arg3_1 = None mul_1: "Sym(u17*u18*u19)" = mul * arg4_1; mul = arg4_1 = None mul_2: "Sym(u15*u16)" = arg0_1 * arg1_1; arg0_1 = arg1_1 = None @@ -7085,7 +7085,7 @@ class GraphModule(torch.nn.Module): return lt class while_loop_body_graph_0(torch.nn.Module): - def forward(self, arg0_1: "Sym(u15)", arg1_1: "Sym(u16)", arg2_1: "Sym(u17)", arg3_1: "Sym(u18)", arg4_1: "Sym(u19)", arg5_1: "f32[s77, 3]"): + def forward(self, arg0_1: "Sym(u15)", arg1_1: "Sym(u16)", arg2_1: "Sym(u17)", arg3_1: "Sym(u18)", arg4_1: "Sym(u19)", arg5_1: "f32[s0, 3]"): add: "Sym(u15 + 1)" = arg0_1 + 1; arg0_1 = None add_1: "Sym(u16 + 1)" = arg1_1 + 1; arg1_1 = None @@ -7093,7 +7093,7 @@ class GraphModule(torch.nn.Module): add_3: "Sym(u18 + 1)" = arg3_1 + 1; arg3_1 = None add_4: "Sym(u19 + 1)" = arg4_1 + 1; arg4_1 = None - add_5: "f32[s77, 3]" = torch.ops.aten.add.Tensor(arg5_1, 1); arg5_1 = None + add_5: "f32[s0, 3]" = torch.ops.aten.add.Tensor(arg5_1, 1); arg5_1 = None return (add, add_1, add_2, add_3, add_4, add_5) """, # noqa: B950 ) @@ -7119,14 +7119,14 @@ class GraphModule(torch.nn.Module): normalize_gm(backend.graphs[0].print_readable(print_output=False)), """\ class GraphModule(torch.nn.Module): - def forward(self, s77: "Sym(s77)", s27: "Sym(s27)", L_x_: "f32[s77, s27]"): + def forward(self, s0: "Sym(s0)", s1: "Sym(s1)", L_x_: "f32[s0, s1]"): l_x_ = L_x_ - child: "f32[s77, s27]" = l_x_.sin(); l_x_ = None + child: "f32[s0, s1]" = l_x_.sin(); l_x_ = None cond_fn_0 = self.cond_fn_0 body_fn_0 = self.body_fn_0 - while_loop = torch.ops.higher_order.while_loop(cond_fn_0, body_fn_0, (s77, s27, 2, 2, 3, child), (s27, s77)); cond_fn_0 = body_fn_0 = s77 = s27 = child = None + while_loop = torch.ops.higher_order.while_loop(cond_fn_0, body_fn_0, (s0, s1, 2, 2, 3, child), (s0, s1)); cond_fn_0 = body_fn_0 = s0 = s1 = child = None getitem_10: "Sym(u5)" = while_loop[0] getitem_11: "Sym(u6)" = while_loop[1] @@ -7134,21 +7134,21 @@ class GraphModule(torch.nn.Module): getitem_13: "Sym(u8)" = while_loop[3] getitem_14: "Sym(u9)" = while_loop[4] - out_x: "f32[s77, s27]" = while_loop[5]; while_loop = None + out_x: "f32[s0, s1]" = while_loop[5]; while_loop = None add: "Sym(u7 + 1)" = getitem_12 + 1 add_1: "Sym(u8 + 1)" = getitem_13 + 1 add_2: "Sym(u9 + 1)" = getitem_14 + 1 - add_3: "f32[s77, s27]" = getitem_12 + out_x; getitem_12 = None - add_4: "f32[s77, s27]" = getitem_13 + out_x; getitem_13 = None - add_5: "f32[s77, s27]" = getitem_14 + out_x; getitem_14 = None + add_3: "f32[s0, s1]" = getitem_12 + out_x; getitem_12 = None + add_4: "f32[s0, s1]" = getitem_13 + out_x; getitem_13 = None + add_5: "f32[s0, s1]" = getitem_14 + out_x; getitem_14 = None return (getitem_10, getitem_11, add, add_1, add_2, add_3, add_4, add_5, out_x) class cond_fn_0(torch.nn.Module): - def forward(self, unbacked_symint: "Sym(u0)", unbacked_symint_0: "Sym(u1)", unbacked_symint_1: "Sym(u2)", unbacked_symint_2: "Sym(u3)", unbacked_symint_3: "Sym(u4)", child: "f32[s77, s27]", s27, s77): - s27_1 = s27 - s77_1 = s77 + def forward(self, unbacked_symint: "Sym(u0)", unbacked_symint_0: "Sym(u1)", unbacked_symint_1: "Sym(u2)", unbacked_symint_2: "Sym(u3)", unbacked_symint_3: "Sym(u4)", child: "f32[s0, s1]", s0, s1): + s0_1 = s0 + s1_1 = s1 mul: "Sym(u2*u3)" = unbacked_symint_1 * unbacked_symint_2; unbacked_symint_1 = unbacked_symint_2 = None mul_1: "Sym(u2*u3*u4)" = mul * unbacked_symint_3; mul = unbacked_symint_3 = None @@ -7157,9 +7157,9 @@ class GraphModule(torch.nn.Module): return lt class body_fn_0(torch.nn.Module): - def forward(self, unbacked_symint: "Sym(u0)", unbacked_symint_0: "Sym(u1)", unbacked_symint_1: "Sym(u2)", unbacked_symint_2: "Sym(u3)", unbacked_symint_3: "Sym(u4)", child: "f32[s77, s27]", s27, s77): - s27_1 = s27 - s77_1 = s77 + def forward(self, unbacked_symint: "Sym(u0)", unbacked_symint_0: "Sym(u1)", unbacked_symint_1: "Sym(u2)", unbacked_symint_2: "Sym(u3)", unbacked_symint_3: "Sym(u4)", child: "f32[s0, s1]", s0, s1): + s0_1 = s0 + s1_1 = s1 add: "Sym(u0 + 1)" = unbacked_symint + 1; unbacked_symint = None add_1: "Sym(u1 + 1)" = unbacked_symint_0 + 1; unbacked_symint_0 = None @@ -7168,7 +7168,7 @@ class GraphModule(torch.nn.Module): add_3: "Sym(u3 + 1)" = unbacked_symint_2 + 1; unbacked_symint_2 = None add_4: "Sym(u4 + 1)" = unbacked_symint_3 + 1; unbacked_symint_3 = None - child_1: "f32[s77, s27]" = child + 1; child = None + child_1: "f32[s0, s1]" = child + 1; child = None return (add, add_1, add_2, add_3, add_4, child_1) """, # noqa: B950 ) @@ -7336,30 +7336,30 @@ class GraphModule(torch.nn.Module): """\ class GraphModule(torch.nn.Module): def forward(self, x, y, z): - x: "f32[s35, 3]"; y: "f32[s58]"; z: "f32[s35, 3]"; + x: "f32[s0, 3]"; y: "f32[s1]"; z: "f32[s0, 3]"; x, y, z, = fx_pytree.tree_flatten_spec(([x, y, z], {}), self._in_spec) - sym_size_int_3: "Sym(s35)" = torch.ops.aten.sym_size.int(x, 0) - sym_size_int_4: "Sym(s58)" = torch.ops.aten.sym_size.int(y, 0); y = None + sym_size_int_3: "Sym(s0)" = torch.ops.aten.sym_size.int(x, 0) + sym_size_int_4: "Sym(s1)" = torch.ops.aten.sym_size.int(y, 0); y = None - gt: "Sym(s35 > 5)" = sym_size_int_3 > 5 + gt: "Sym(s0 > 5)" = sym_size_int_3 > 5 true_graph_0 = self.true_graph_0 false_graph_0 = self.false_graph_0 cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, (x, sym_size_int_4, sym_size_int_3, z)); gt = true_graph_0 = false_graph_0 = x = sym_size_int_4 = sym_size_int_3 = z = None - getitem: "f32[s35, 3]" = cond[0]; cond = None + getitem: "f32[s0, 3]" = cond[0]; cond = None return pytree.tree_unflatten((getitem,), self._out_spec) class true_graph_0(torch.nn.Module): - def forward(self, x: "f32[s35, 3]", sym_size_int_4: "Sym(s58)", sym_size_int_3: "Sym(s35)", z: "f32[s35, 3]"): - add: "f32[s35, 3]" = torch.ops.aten.add.Tensor(x, sym_size_int_4); x = sym_size_int_4 = None + def forward(self, x: "f32[s0, 3]", sym_size_int_4: "Sym(s1)", sym_size_int_3: "Sym(s0)", z: "f32[s0, 3]"): + add: "f32[s0, 3]" = torch.ops.aten.add.Tensor(x, sym_size_int_4); x = sym_size_int_4 = None return (add,) class false_graph_0(torch.nn.Module): - def forward(self, x: "f32[s35, 3]", sym_size_int_4: "Sym(s58)", sym_size_int_3: "Sym(s35)", z: "f32[s35, 3]"): - mul: "f32[s35, 3]" = torch.ops.aten.mul.Tensor(z, sym_size_int_3); z = sym_size_int_3 = None + def forward(self, x: "f32[s0, 3]", sym_size_int_4: "Sym(s1)", sym_size_int_3: "Sym(s0)", z: "f32[s0, 3]"): + mul: "f32[s0, 3]" = torch.ops.aten.mul.Tensor(z, sym_size_int_3); z = sym_size_int_3 = None - add: "f32[s35, 3]" = torch.ops.aten.add.Tensor(x, mul); x = mul = None + add: "f32[s0, 3]" = torch.ops.aten.add.Tensor(x, mul); x = mul = None return (add,) """, # noqa: B950 ) @@ -7522,7 +7522,7 @@ class GraphModule(torch.nn.Module): normalize_gm(bk.graphs[0].print_readable(print_output=False)), """\ class GraphModule(torch.nn.Module): - def forward(self, s17: "Sym(s17)", s94: "Sym(s94)", L_y_: "f32[s17, s94]", L_z_: "f32[s17, s94]", L_x_: "f32[s17, s94]"): + def forward(self, s0: "Sym(s0)", s1: "Sym(s1)", L_y_: "f32[s0, s1]", L_z_: "f32[s0, s1]", L_x_: "f32[s0, s1]"): l_y_ = L_y_ l_z_ = L_z_ l_x_ = L_x_ @@ -7532,39 +7532,39 @@ class GraphModule(torch.nn.Module): cond_true_0 = self.cond_true_0 cond_false_0 = self.cond_false_0 - cond = torch.ops.higher_order.cond(gt, cond_true_0, cond_false_0, (l_x_, s94, s17, s17, l_z_)); gt = cond_true_0 = cond_false_0 = l_x_ = s94 = s17 = l_z_ = None + cond = torch.ops.higher_order.cond(gt, cond_true_0, cond_false_0, (l_x_, s1, s0, s0, l_z_)); gt = cond_true_0 = cond_false_0 = l_x_ = s1 = s0 = l_z_ = None - getitem_5: "f32[u0, s94]" = cond[0] + getitem_5: "f32[u0, s1]" = cond[0] sym_size_int: "Sym(u0)" = torch.ops.aten.sym_size.int(getitem_5, 0); getitem_5 = None _check_is_size = torch._check_is_size(sym_size_int); _check_is_size = None ge: "Sym(u0 >= 0)" = sym_size_int >= 0; sym_size_int = None _assert_scalar_default = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge = _assert_scalar_default = None - ret: "f32[u0, s94]" = cond[0]; cond = None + ret: "f32[u0, s1]" = cond[0]; cond = None sum_2: "f32[]" = l_y_.sum(); l_y_ = None - sub: "f32[u0, s94]" = sum_2 - ret; sum_2 = ret = None + sub: "f32[u0, s1]" = sum_2 - ret; sum_2 = ret = None return (sub,) class cond_true_0(torch.nn.Module): - def forward(self, l_x_, s94, s17_true_branch, getitem_2_false_branch, l_z__false_branch): + def forward(self, l_x_, s1, s0_true_branch, getitem_2_false_branch, l_z__false_branch): l_x__1 = l_x_ - s94_1 = s94 + s1_1 = s1 - add: "f32[s17, s94]" = l_x__1 + s17_true_branch; l_x__1 = s17_true_branch = None - getitem: "f32[s17 - 2, s94]" = add[slice(2, None, None)]; add = None - clone: "f32[s17 - 2, s94]" = getitem.clone(); getitem = None + add: "f32[s0, s1]" = l_x__1 + s0_true_branch; l_x__1 = s0_true_branch = None + getitem: "f32[s0 - 2, s1]" = add[slice(2, None, None)]; add = None + clone: "f32[s0 - 2, s1]" = getitem.clone(); getitem = None return (clone,) class cond_false_0(torch.nn.Module): - def forward(self, l_x_, s94, s17_true_branch, getitem_2_false_branch, l_z__false_branch): + def forward(self, l_x_, s1, s0_true_branch, getitem_2_false_branch, l_z__false_branch): l_x__1 = l_x_ - s94_1 = s94 + s1_1 = s1 - mul: "f32[s17, s94]" = getitem_2_false_branch * l_z__false_branch; getitem_2_false_branch = l_z__false_branch = None - add: "f32[s17, s94]" = l_x__1 + mul; l_x__1 = mul = None - getitem: "f32[2, s94]" = add[slice(None, 2, None)]; add = None - clone: "f32[2, s94]" = getitem.clone(); getitem = None + mul: "f32[s0, s1]" = getitem_2_false_branch * l_z__false_branch; getitem_2_false_branch = l_z__false_branch = None + add: "f32[s0, s1]" = l_x__1 + mul; l_x__1 = mul = None + getitem: "f32[2, s1]" = add[slice(None, 2, None)]; add = None + clone: "f32[2, s1]" = getitem.clone(); getitem = None return (clone,) """, # noqa: B950 ) diff --git a/test/inductor/test_auto_functionalize.py b/test/inductor/test_auto_functionalize.py index abe335ffa74..7033d362170 100644 --- a/test/inductor/test_auto_functionalize.py +++ b/test/inductor/test_auto_functionalize.py @@ -403,10 +403,10 @@ arg2_1 = arg3_1 = arg1_1 = arg0_1 = foo_default = None self.assertExpectedInline( post_grad_graphs, """\ -def forward(self, arg0_1: "Sym(s72)", arg1_1: "f32[s72][1]cpu", arg2_1: "f32[s72][1]cpu", arg3_1: "f32[s72][1]cpu", arg4_1: "f32[s72][1]cpu", arg5_1: "f32[s72][1]cpu"): +def forward(self, arg0_1: "Sym(s0)", arg1_1: "f32[s0][1]cpu", arg2_1: "f32[s0][1]cpu", arg3_1: "f32[s0][1]cpu", arg4_1: "f32[s0][1]cpu", arg5_1: "f32[s0][1]cpu"): foo_default = torch.ops.mylib.foo.default(arg3_1, [arg4_1, arg5_1], arg2_1, 2, arg1_1); arg4_1 = arg5_1 = arg1_1 = foo_default = None - copy_: "f32[s72][1]cpu" = torch.ops.aten.copy_.default(arg2_1, arg2_1); arg2_1 = copy_ = None - copy__1: "f32[s72][1]cpu" = torch.ops.aten.copy_.default(arg3_1, arg3_1); arg3_1 = copy__1 = None + copy_: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg2_1, arg2_1); arg2_1 = copy_ = None + copy__1: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg3_1, arg3_1); arg3_1 = copy__1 = None return ()""", # noqa: B950 ignore_comments=True, ignore_empty_lines=True, @@ -563,13 +563,13 @@ def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3 self.assertExpectedInline( graph_aot, """\ -def forward(self, arg0_1: "Sym(s17)", arg1_1: "f32[s17][1]cpu", arg2_1: "f32[s17][1]cpu"): +def forward(self, arg0_1: "Sym(s0)", arg1_1: "f32[s0][1]cpu", arg2_1: "f32[s0][1]cpu"): auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.mylib.foo.default, _x_base_index = 0, _y_base_index = 1, _all_bases = [arg2_1, arg1_1]) - getitem_1: "f32[s17][1]cpu" = auto_functionalized_v2[1] - getitem_2: "f32[s17][1]cpu" = auto_functionalized_v2[2]; auto_functionalized_v2 = None - add: "f32[s17][1]cpu" = torch.ops.aten.add.Tensor(getitem_1, getitem_2) - copy_: "f32[s17][1]cpu" = torch.ops.aten.copy_.default(arg1_1, getitem_2); arg1_1 = getitem_2 = copy_ = None - copy__1: "f32[s17][1]cpu" = torch.ops.aten.copy_.default(arg2_1, getitem_1); arg2_1 = getitem_1 = copy__1 = None + getitem_1: "f32[s0][1]cpu" = auto_functionalized_v2[1] + getitem_2: "f32[s0][1]cpu" = auto_functionalized_v2[2]; auto_functionalized_v2 = None + add: "f32[s0][1]cpu" = torch.ops.aten.add.Tensor(getitem_1, getitem_2) + copy_: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg1_1, getitem_2); arg1_1 = getitem_2 = copy_ = None + copy__1: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg2_1, getitem_1); arg2_1 = getitem_1 = copy__1 = None return (add,)""", # noqa: B950 ignore_comments=True, ignore_empty_lines=True, @@ -595,11 +595,11 @@ def forward(self, arg0_1: "f32[2][1]cpu", arg1_1: "f32[2][1]cpu"): self.assertExpectedInline( graph_inductor, """\ -def forward(self, arg0_1: "Sym(s17)", arg1_1: "f32[s17][1]cpu", arg2_1: "f32[s17][1]cpu"): +def forward(self, arg0_1: "Sym(s0)", arg1_1: "f32[s0][1]cpu", arg2_1: "f32[s0][1]cpu"): foo_default = torch.ops.mylib.foo.default(arg2_1, arg1_1); foo_default = None - add: "f32[s17][1]cpu" = torch.ops.aten.add.Tensor(arg2_1, arg1_1) - copy_: "f32[s17][1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1); arg1_1 = copy_ = None - copy__1: "f32[s17][1]cpu" = torch.ops.aten.copy_.default(arg2_1, arg2_1); arg2_1 = copy__1 = None + add: "f32[s0][1]cpu" = torch.ops.aten.add.Tensor(arg2_1, arg1_1) + copy_: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1); arg1_1 = copy_ = None + copy__1: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg2_1, arg2_1); arg2_1 = copy__1 = None return (add,)""", ignore_comments=True, ignore_empty_lines=True, @@ -663,10 +663,10 @@ def forward(self, arg0_1: "f32[2][1]cpu", arg1_1: "f32[2][1]cpu"): self.assertExpectedInline( graph_aot, """\ -def forward(self, arg0_1: "Sym(s77)", arg1_1: "f32[s77][1]cpu"): +def forward(self, arg0_1: "Sym(s0)", arg1_1: "f32[s0][1]cpu"): auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.mylib.foo.default, _x_base_index = 0, _x_size = (), _x_stride = (), _x_storage_offset = 0, _y_base_index = 0, _y_size = (), _y_stride = (), _y_storage_offset = 1, _all_bases = [arg1_1]) - getitem_1: "f32[s77][1]cpu" = auto_functionalized_v2[1]; auto_functionalized_v2 = None - copy_: "f32[s77][1]cpu" = torch.ops.aten.copy_.default(arg1_1, getitem_1); arg1_1 = getitem_1 = copy_ = None + getitem_1: "f32[s0][1]cpu" = auto_functionalized_v2[1]; auto_functionalized_v2 = None + copy_: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg1_1, getitem_1); arg1_1 = getitem_1 = copy_ = None return ()""", # noqa: B950 ignore_comments=True, ignore_empty_lines=True, @@ -691,11 +691,11 @@ def forward(self, arg0_1: "f32[2][1]cpu"): self.assertExpectedInline( graph_inductor, """\ -def forward(self, arg0_1: "Sym(s77)", arg1_1: "f32[s77][1]cpu"): +def forward(self, arg0_1: "Sym(s0)", arg1_1: "f32[s0][1]cpu"): as_strided_default: "f32[][]cpu" = torch.ops.aten.as_strided.default(arg1_1, [], [], 0) as_strided_default_1: "f32[][]cpu" = torch.ops.aten.as_strided.default(arg1_1, [], [], 1) foo_default = torch.ops.mylib.foo.default(as_strided_default, as_strided_default_1); as_strided_default = as_strided_default_1 = foo_default = None - copy_: "f32[s77][1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1); arg1_1 = copy_ = None + copy_: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1); arg1_1 = copy_ = None return ()""", # noqa: B950 ignore_comments=True, ignore_empty_lines=True, @@ -1291,14 +1291,14 @@ def forward(self, arg0_1: "f32[10, 10][10, 1]cpu"): self.assertExpectedInline( graph_aot, """\ -def forward(self, arg0_1: "Sym(s77)", arg1_1: "f32[s77, s77][s77, 1]cpu"): +def forward(self, arg0_1: "Sym(s0)", arg1_1: "f32[s0, s0][s0, 1]cpu"): floordiv: "Sym(0)" = 0 // arg0_1; arg0_1 = None add_6: "Sym(2)" = floordiv + 2 auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.mylib.foo.default, _x_base_index = 0, _x_slice_dim = 0, _x_slice_start = floordiv, _x_slice_end = add_6, _y_base_index = 0, _y_slice_dim = 1, _y_slice_start = 3, _y_slice_end = 4, _all_bases = [arg1_1]); floordiv = add_6 = None - getitem_1: "f32[s77, s77][s77, 1]cpu" = auto_functionalized_v2[1]; auto_functionalized_v2 = None - copy_: "f32[s77, s77][s77, 1]cpu" = torch.ops.aten.copy_.default(arg1_1, getitem_1); arg1_1 = copy_ = None - slice_3: "f32[2, s77][s77, 1]cpu" = torch.ops.aten.slice.Tensor(getitem_1, 0, 0, 2) - slice_4: "f32[s77, 1][s77, 1]cpu" = torch.ops.aten.slice.Tensor(getitem_1, 1, 3, 4); getitem_1 = None + getitem_1: "f32[s0, s0][s0, 1]cpu" = auto_functionalized_v2[1]; auto_functionalized_v2 = None + copy_: "f32[s0, s0][s0, 1]cpu" = torch.ops.aten.copy_.default(arg1_1, getitem_1); arg1_1 = copy_ = None + slice_3: "f32[2, s0][s0, 1]cpu" = torch.ops.aten.slice.Tensor(getitem_1, 0, 0, 2) + slice_4: "f32[s0, 1][s0, 1]cpu" = torch.ops.aten.slice.Tensor(getitem_1, 1, 3, 4); getitem_1 = None return (slice_3, slice_4)""", # noqa: B950 ignore_comments=True, ignore_empty_lines=True, @@ -1324,13 +1324,13 @@ def forward(self, arg0_1: "f32[10, 10][10, 1]cpu"): self.assertExpectedInline( graph_inductor, """\ -def forward(self, arg0_1: "Sym(s77)", arg1_1: "f32[s77, s77][s77, 1]cpu"): - slice_tensor: "f32[2, s77][s77, 1]cpu" = torch.ops.aten.slice.Tensor(arg1_1, 0, 0, 2) - slice_tensor_1: "f32[s77, 1][s77, 1]cpu" = torch.ops.aten.slice.Tensor(arg1_1, 1, 3, 4) +def forward(self, arg0_1: "Sym(s0)", arg1_1: "f32[s0, s0][s0, 1]cpu"): + slice_tensor: "f32[2, s0][s0, 1]cpu" = torch.ops.aten.slice.Tensor(arg1_1, 0, 0, 2) + slice_tensor_1: "f32[s0, 1][s0, 1]cpu" = torch.ops.aten.slice.Tensor(arg1_1, 1, 3, 4) foo_default = torch.ops.mylib.foo.default(slice_tensor, slice_tensor_1); slice_tensor = slice_tensor_1 = foo_default = None - copy_: "f32[s77, s77][s77, 1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1); copy_ = None - slice_3: "f32[2, s77][s77, 1]cpu" = torch.ops.aten.slice.Tensor(arg1_1, 0, 0, 2) - slice_4: "f32[s77, 1][s77, 1]cpu" = torch.ops.aten.slice.Tensor(arg1_1, 1, 3, 4); arg1_1 = None + copy_: "f32[s0, s0][s0, 1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1); copy_ = None + slice_3: "f32[2, s0][s0, 1]cpu" = torch.ops.aten.slice.Tensor(arg1_1, 0, 0, 2) + slice_4: "f32[s0, 1][s0, 1]cpu" = torch.ops.aten.slice.Tensor(arg1_1, 1, 3, 4); arg1_1 = None return (slice_3, slice_4)""", # noqa: B950 ignore_comments=True, ignore_empty_lines=True, @@ -1470,18 +1470,18 @@ def forward(self, arg0_1: "f32[10, 10][10, 1]cpu"): self.assertExpectedInline( graph_aot, """\ -def forward(self, arg0_1: "Sym(s77)", arg1_1: "f32[s77][1]cpu"): - clone: "f32[s77][1]cpu" = torch.ops.aten.clone.default(arg1_1) +def forward(self, arg0_1: "Sym(s0)", arg1_1: "f32[s0][1]cpu"): + clone: "f32[s0][1]cpu" = torch.ops.aten.clone.default(arg1_1) nonzero: "i64[u0, 1][1, u0]cpu" = torch.ops.aten.nonzero.default(clone); clone = None sym_size_int_1: "Sym(u0)" = torch.ops.aten.sym_size.int(nonzero, 0) ge_1: "Sym(u0 >= 0)" = sym_size_int_1 >= 0; sym_size_int_1 = None _assert_scalar = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge_1 = _assert_scalar = None _to_copy: "f32[u0, 1][1, u0]cpu" = torch.ops.aten._to_copy.default(nonzero, dtype = torch.float32); nonzero = None auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.mylib.foo.default, _x_base_index = 0, _x_alias = True, _y_base_index = 1, _y_alias = True, _all_bases = [arg1_1, _to_copy]); _to_copy = None - getitem_1: "f32[s77][1]cpu" = auto_functionalized_v2[1] + getitem_1: "f32[s0][1]cpu" = auto_functionalized_v2[1] getitem_2: "f32[u0, 1][1, u0]cpu" = auto_functionalized_v2[2]; auto_functionalized_v2 = None - copy_: "f32[s77][1]cpu" = torch.ops.aten.copy_.default(arg1_1, getitem_1); arg1_1 = copy_ = None - alias_1: "f32[s77][1]cpu" = torch.ops.aten.alias.default(getitem_1); getitem_1 = None + copy_: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg1_1, getitem_1); arg1_1 = copy_ = None + alias_1: "f32[s0][1]cpu" = torch.ops.aten.alias.default(getitem_1); getitem_1 = None slice_2: "f32[u0, 1][1, u0]cpu" = torch.ops.aten.slice.Tensor(getitem_2); getitem_2 = None return (alias_1, slice_2)""", # noqa: B950 ignore_comments=True, @@ -1517,16 +1517,16 @@ def forward(self, arg0_1: "f32[2][1]cpu"): self.assertExpectedInline( graph_inductor, """\ -def forward(self, arg0_1: "Sym(s77)", arg1_1: "f32[s77][1]cpu"): +def forward(self, arg0_1: "Sym(s0)", arg1_1: "f32[s0][1]cpu"): nonzero: "i64[u0, 1][1, u0]cpu" = torch.ops.aten.nonzero.default(arg1_1) sym_size_int_1: "Sym(u0)" = torch.ops.aten.sym_size.int(nonzero, 0) ge_1: "Sym(u0 >= 0)" = sym_size_int_1 >= 0; sym_size_int_1 = None _assert_scalar = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge_1 = _assert_scalar = None convert_element_type: "f32[u0, 1][1, u0]cpu" = torch.ops.prims.convert_element_type.default(nonzero, torch.float32); nonzero = None - alias_default: "f32[s77][1]cpu" = torch.ops.aten.alias.default(arg1_1) + alias_default: "f32[s0][1]cpu" = torch.ops.aten.alias.default(arg1_1) alias_default_1: "f32[u0, 1][1, u0]cpu" = torch.ops.aten.alias.default(convert_element_type) foo_default = torch.ops.mylib.foo.default(alias_default, alias_default_1); alias_default = alias_default_1 = foo_default = None - copy_: "f32[s77][1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1); copy_ = None + copy_: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1); copy_ = None slice_2: "f32[u0, 1][1, u0]cpu" = torch.ops.aten.slice.Tensor(convert_element_type); convert_element_type = None return (arg1_1, slice_2)""", # noqa: B950 ignore_comments=True, diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index ad3986cd937..776afea317d 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -11822,7 +11822,7 @@ class CommonTemplate: # 'i1 + 3 * i0' is cached. self.assertTrue( "i0 + 2 * i1" in mul_buf.data.inner_fn_str() - or "i0 + i1 * s64" in mul_buf.data.inner_fn_str() + or "i0 + i1 * s1" in mul_buf.data.inner_fn_str() ) with add_scheduler_init_hook(hook_fn): @@ -12551,7 +12551,7 @@ class CommonTemplate: torch.testing.assert_close(ref, act, atol=1e-3, rtol=1e-3) if is_dynamic_shape_enabled(): - size_assert_pattern = r"assert_size_stride.[a-z]+[0-9]+, .2, 3, s12, s80, s80., .3\*s12\*s80\*s80, s12\*s80\*s80, 1, s12\*s80, s1.." # noqa: B950 + size_assert_pattern = r"assert_size_stride.[a-z]+[0-9]+, .2, 3, s1, s2, s2., .3\*s1\*s2\*s2, s1\*s2\*s2, 1, s1\*s2, s1.." # noqa: B950 else: size_assert_pattern = r"assert_size_stride.[a-z]+[0-9]+, .2, 3, 16, 32, 32., .49152, 16384, 1, 512, 16.." FileCheck().check_regex(size_assert_pattern).run(code) @@ -12570,8 +12570,8 @@ class CommonTemplate: code = run_and_get_triton_code(f, x) if is_dynamic_shape_enabled(): - FileCheck().check("assert_size_stride(buf1, (s77, s27), (s27, 1))").check( - "assert_size_stride(buf2, (s77, s27), (s27, 1))" + FileCheck().check("assert_size_stride(buf1, (s0, s1), (s1, 1))").check( + "assert_size_stride(buf2, (s0, s1), (s1, 1))" ).run(code) else: FileCheck().check("assert_size_stride(buf1, (16, 32), (32, 1))").check( diff --git a/test/test_dynamic_shapes.py b/test/test_dynamic_shapes.py index a3458efbe65..27a54d83cbd 100644 --- a/test/test_dynamic_shapes.py +++ b/test/test_dynamic_shapes.py @@ -393,7 +393,7 @@ class TestPySymInt(TestCase): self.assertEqual(res_and, 0b1000) self.assertIsInstance(res_and, torch.SymInt, msg=type(res_and)) self.assertExpectedInline( - str(shape_env.guards[0][0]), """Eq(BitwiseFn_bitwise_and(s97, s26), 8)""" + str(shape_env.guards[0][0]), """Eq(BitwiseFn_bitwise_and(s0, s1), 8)""" ) a1 = create_symint(shape_env, 3) @@ -415,7 +415,7 @@ class TestPySymInt(TestCase): self.assertEqual(res_or, 0b1110) self.assertIsInstance(res_or, torch.SymInt, msg=type(res_or)) self.assertExpectedInline( - str(shape_env.guards[0][0]), """Eq(BitwiseFn_bitwise_or(s97, s26), 14)""" + str(shape_env.guards[0][0]), """Eq(BitwiseFn_bitwise_or(s0, s1), 14)""" ) def test_stride(self): @@ -497,7 +497,7 @@ class TestPySymInt(TestCase): shape_env = ShapeEnv() a0 = create_symint(shape_env, 2) self.assertEqual(guard_int(a0), 2) - self.assertExpectedInline(str(shape_env.guards[0][0]), """Eq(s97, 2)""") + self.assertExpectedInline(str(shape_env.guards[0][0]), """Eq(s0, 2)""") def test_sym_sum(self): shape_env = ShapeEnv() @@ -512,7 +512,7 @@ class TestPySymInt(TestCase): shape_env = ShapeEnv(prefer_deferred_runtime_asserts_over_guards=True) s0 = create_symint(shape_env, 2) self.assertEqual(guard_int(s0), 2) - self.assertExpectedInline(str(shape_env.guards[0][0]), """Eq(s97, 2)""") + self.assertExpectedInline(str(shape_env.guards[0][0]), """Eq(s0, 2)""") shape_env = ShapeEnv(prefer_deferred_runtime_asserts_over_guards=True) s0 = create_symint(shape_env, 2) @@ -520,7 +520,7 @@ class TestPySymInt(TestCase): self.assertEqual(len(shape_env.guards), 0) self.assertExpectedInline( str([ra.expr for ra in shape_env.deferred_runtime_asserts[None]]), - """[Eq(s97, 2)]""", + """[Eq(s0, 2)]""", ) def test_sym_int(self): @@ -529,14 +529,14 @@ class TestPySymInt(TestCase): r = sym_int(a0) self.assertEqual(r, 5) self.assertIsInstance(r, torch.SymInt, msg=type(r)) - self.assertExpectedInline(str(shape_env.guards[0][0]), """Eq(s97, 5)""") + self.assertExpectedInline(str(shape_env.guards[0][0]), """Eq(s0, 5)""") a1 = create_symint(shape_env, 7) r = sym_int(a1 / 2) self.assertEqual(guard_int(r), 3) self.assertIsInstance(r, torch.SymInt, msg=type(r)) self.assertExpectedInline( - str(shape_env.guards[1][0]), """Eq(TruncToInt(IntTrueDiv(s26, 2)), 3)""" + str(shape_env.guards[1][0]), """Eq(TruncToInt(IntTrueDiv(s1, 2)), 3)""" ) a3 = create_symint(shape_env, 3) @@ -544,7 +544,7 @@ class TestPySymInt(TestCase): self.assertEqual(guard_int(r), 6) self.assertIsInstance(r, torch.SymInt, msg=type(r)) self.assertExpectedInline( - str(shape_env.guards[2][0]), """Eq(TruncToInt(2.0*ToFloat(s57)), 6)""" + str(shape_env.guards[2][0]), """Eq(TruncToInt(2.0*ToFloat(s2)), 6)""" ) def test_sym_log2(self): @@ -554,7 +554,7 @@ class TestPySymInt(TestCase): self.assertEqual(r, 2.0) self.assertIsInstance(r, torch.SymFloat, msg=type(r)) self.assertExpectedInline( - str(shape_env.guards[0][0]), """Eq(OpaqueUnaryFn_log2(ToFloat(s97)), 2.0)""" + str(shape_env.guards[0][0]), """Eq(OpaqueUnaryFn_log2(ToFloat(s0)), 2.0)""" ) def test_sym_sqrt(self): @@ -564,7 +564,7 @@ class TestPySymInt(TestCase): self.assertEqual(r, 2) self.assertIsInstance(r, torch.SymFloat, msg=type(r)) self.assertExpectedInline( - str(shape_env.guards[0][0]), """Eq(OpaqueUnaryFn_sqrt(ToFloat(s97)), 2.0)""" + str(shape_env.guards[0][0]), """Eq(OpaqueUnaryFn_sqrt(ToFloat(s0)), 2.0)""" ) def test_sym_floor(self): @@ -575,14 +575,14 @@ class TestPySymInt(TestCase): self.assertIsInstance(r, torch.SymInt, msg=type(r)) self.assertExpectedInline( str(shape_env.guards[0][0]), - """Eq(FloorToInt(IntTrueDiv(s97, 2)), 2)""", + """Eq(FloorToInt(IntTrueDiv(s0, 2)), 2)""", ) r = math.floor(3.0 * a0) self.assertEqual(r, 15) self.assertIsInstance(r, torch.SymInt, msg=type(r)) self.assertExpectedInline( str(shape_env.guards[1][0]), - """Eq(FloorToInt(3.0*ToFloat(s97)), 15)""", + """Eq(FloorToInt(3.0*ToFloat(s0)), 15)""", ) def test_sym_trunc(self): @@ -592,14 +592,14 @@ class TestPySymInt(TestCase): self.assertEqual(r, 2) self.assertIsInstance(r, torch.SymInt, msg=type(r)) self.assertExpectedInline( - str(shape_env.guards[0][0]), """Eq(TruncToInt(IntTrueDiv(s97, 2)), 2)""" + str(shape_env.guards[0][0]), """Eq(TruncToInt(IntTrueDiv(s0, 2)), 2)""" ) r = torch.sym_int(torch.sym_sqrt(a0)) self.assertEqual(r, 2) self.assertIsInstance(r, torch.SymInt, msg=type(r)) self.assertExpectedInline( str(shape_env.guards[1][0]), - """Eq(TruncToInt(OpaqueUnaryFn_sqrt(ToFloat(s97))), 2)""", + """Eq(TruncToInt(OpaqueUnaryFn_sqrt(ToFloat(s0))), 2)""", ) def test_sym_ceil(self): @@ -610,7 +610,7 @@ class TestPySymInt(TestCase): self.assertIsInstance(r, torch.SymInt, msg=type(r)) self.assertExpectedInline( str(shape_env.guards[0][0]), - """Eq(CeilToInt(IntTrueDiv(s97, 2)), 3)""", + """Eq(CeilToInt(IntTrueDiv(s0, 2)), 3)""", ) r1 = 3.0 * a0 r = math.floor(r1) @@ -618,7 +618,7 @@ class TestPySymInt(TestCase): self.assertIsInstance(r, torch.SymInt, msg=type(r)) self.assertExpectedInline( str(shape_env.guards[1][0]), - """Eq(FloorToInt(3.0*ToFloat(s97)), 15)""", + """Eq(FloorToInt(3.0*ToFloat(s0)), 15)""", ) def test_sym_ite(self): @@ -638,7 +638,7 @@ class TestPySymInt(TestCase): self.assertEqual(type(t), type(r3)) self.assertExpectedInline( str(shape_env.guards[0][0]), - """Eq(Piecewise((s97, Eq(s97, 5)), (s26, True)), 5)""", + """Eq(Piecewise((s0, Eq(s0, 5)), (s1, True)), 5)""", ) b4 = f == 5 r4 = torch.sym_ite(b4, t, f) @@ -647,7 +647,7 @@ class TestPySymInt(TestCase): self.assertEqual(type(f), type(r4)) self.assertExpectedInline( str(shape_env.guards[1][0]), - """Eq(Piecewise((s97, Eq(s26, 5)), (s26, True)), 4)""", + """Eq(Piecewise((s0, Eq(s1, 5)), (s1, True)), 4)""", ) def test_tracing_sym_ite(self): @@ -679,7 +679,7 @@ def forward(self, x_1): shape_env = ShapeEnv() a0 = create_symint(shape_env, 2) int(a0) - self.assertExpectedInline(str(shape_env.guards[0][0]), """Eq(s97, 2)""") + self.assertExpectedInline(str(shape_env.guards[0][0]), """Eq(s0, 2)""") def test_data_dependent_guard(self): shape_env = ShapeEnv() @@ -710,7 +710,7 @@ def forward(self, x_1): self.assertTrue(expect_true(i0 < s0)) self.assertExpectedInline( str([ra.expr for ra in shape_env.deferred_runtime_asserts[i0.node.expr]]), - """[u0 < s97]""", + """[u0 < s0]""", ) self.assertTrue(i0 < s0) self.assertTrue(i0 != s0) @@ -1173,18 +1173,18 @@ def forward(self, x_1): out.strip(), """\ class f(torch.nn.Module): - def forward(self, a_1: "f32[s75, s96]", b_1: "f32[s57, s96]"): + def forward(self, a_1: "f32[s0, s1]", b_1: "f32[s2, s1]"): # No stacktrace found for following nodes - sym_size_int: "Sym(s75)" = torch.ops.aten.sym_size.int(a_1, 0) - sym_size_int_1: "Sym(s57)" = torch.ops.aten.sym_size.int(b_1, 0) - add: "Sym(s57 + s75)" = sym_size_int + sym_size_int_1; sym_size_int = sym_size_int_1 = None - sym_size_int_2: "Sym(s96)" = torch.ops.aten.sym_size.int(a_1, 1) - sym_size_int_3: "Sym(s96)" = torch.ops.aten.sym_size.int(b_1, 1); b_1 = None - add_1: "Sym(2*s96)" = sym_size_int_2 + sym_size_int_3; sym_size_int_2 = sym_size_int_3 = None - new_empty: "f32[s57 + s75, 2*s96]" = torch.ops.aten.new_empty.default(a_1, [add, add_1], pin_memory = False); a_1 = add = add_1 = None + sym_size_int: "Sym(s0)" = torch.ops.aten.sym_size.int(a_1, 0) + sym_size_int_1: "Sym(s2)" = torch.ops.aten.sym_size.int(b_1, 0) + add: "Sym(s0 + s2)" = sym_size_int + sym_size_int_1; sym_size_int = sym_size_int_1 = None + sym_size_int_2: "Sym(s1)" = torch.ops.aten.sym_size.int(a_1, 1) + sym_size_int_3: "Sym(s1)" = torch.ops.aten.sym_size.int(b_1, 1); b_1 = None + add_1: "Sym(2*s1)" = sym_size_int_2 + sym_size_int_3; sym_size_int_2 = sym_size_int_3 = None + new_empty: "f32[s0 + s2, 2*s1]" = torch.ops.aten.new_empty.default(a_1, [add, add_1], pin_memory = False); a_1 = add = add_1 = None native_dropout = torch.ops.aten.native_dropout.default(new_empty, 0.5, True); new_empty = None - getitem: "f32[s57 + s75, 2*s96]" = native_dropout[0] - getitem_1: "b8[s57 + s75, 2*s96]" = native_dropout[1]; native_dropout = None + getitem: "f32[s0 + s2, 2*s1]" = native_dropout[0] + getitem_1: "b8[s0 + s2, 2*s1]" = native_dropout[1]; native_dropout = None return (getitem, getitem_1)""", # noqa: B950 ) @@ -2846,8 +2846,8 @@ class TestGuardsExpressions(TestCase): ], ) - self.assertEqual(f"{x.stride()}", "(s49, 1)") - self.assertEqual(f"{x.shape}", "torch.Size([s26, s49])") + self.assertEqual(f"{x.stride()}", "(s1, 1)") + self.assertEqual(f"{x.shape}", "torch.Size([s0, s1])") x_clean = _remove_symbols_without_guarding(x, 4096) diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py index 40c4750f04f..e5da232b2a5 100644 --- a/test/test_proxy_tensor.py +++ b/test/test_proxy_tensor.py @@ -1084,7 +1084,7 @@ def forward(self, x_1, y_1): test_inputs.append([(6, 8)]) gm = self._test_dynamic(f, [(3, 4)], test_inputs) self.assertTrue(eval_guards(gm, torch.randn(4, 5))) - self.assertEqual(repr(bind_symbols(gm, torch.randn(4, 5))), "{s75: 4, s96: 5}") + self.assertEqual(repr(bind_symbols(gm, torch.randn(4, 5))), "{s0: 4, s1: 5}") self.assertFalse(eval_guards(gm, torch.randn(25, 5))) self.assertExpectedInline(show_guards(gm), """L['x'].size()[0] <= 19""") @@ -1717,7 +1717,7 @@ def forward(self, a_1): gm = self._test_dynamic(f, [(1, 6), (8, 1)], test_inputs) self.assertTrue(eval_guards(gm, torch.randn(1, 10), torch.randn(6, 1))) self.assertFalse(eval_guards(gm, torch.randn(1, 2), torch.randn(4, 1))) - self.assertExpectedInline(show_guards(gm), """2*L['b'].size()[0]*L['a'].size()[1] > 20""") + self.assertExpectedInline(show_guards(gm), """2*L['a'].size()[1]*L['b'].size()[0] > 20""") def test_new_empty(self): def f(a, b): diff --git a/torch/fx/experimental/sym_node.py b/torch/fx/experimental/sym_node.py index fa4443b1b5d..4fe1421bc06 100644 --- a/torch/fx/experimental/sym_node.py +++ b/torch/fx/experimental/sym_node.py @@ -869,12 +869,6 @@ def _optimized_add( if sortkey(lhs._args[0]) > sortkey(rhs._args[-1]): return make_optimized(rhs._args + lhs._args) - # (a1+a3) + (a0+a2) => (a0+a1+a2+a3) - new_args = list(lhs._args) - for a in rhs._args: - new_args = _binary_search_insert_arg(new_args, a) - return make_optimized(new_args) - # (a0+a2) + a1 => (a0+a1+a2) if lhs_is_optimized_summation and rhs.is_symbol: new_args = _binary_search_insert_arg(list(lhs._args), rhs) diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py index 8a04dc0552a..ff20a582810 100644 --- a/torch/fx/experimental/symbolic_shapes.py +++ b/torch/fx/experimental/symbolic_shapes.py @@ -14,7 +14,6 @@ import atexit import collections import dis import functools -import hashlib import inspect import itertools import logging @@ -3290,11 +3289,6 @@ class ShapeEnv: self.guards: list[ShapeGuard] = [] self.axioms: dict[sympy.Expr, sympy.Expr] = {} - - # A set of ids that have already been allocated. This is used - # for when we allocate symbol ids using the hash of the source - # names to ensure we don't have collisions via linear probing - self.unique_ids: set[int] = set() # Maps symbolic ints to their original concrete values # Currently populated from tensors self.var_to_val: dict[sympy.Symbol, sympy.Integer] = {} @@ -4546,14 +4540,13 @@ class ShapeEnv: # If we're not duck shaping, we always create a new symbol # Even if we're duck shaping, if we haven't seen this particular # value before, we also create a new symbol - symbol_id = self._generate_unique_id(source.name()) if type(val) is int or is_nested_int(val): sympy_expr = make_symbol( - SymT.SIZE, symbol_id, positive=positive, integer=True + SymT.SIZE, len(self.var_to_val), positive=positive, integer=True ) else: sympy_expr = make_symbol( - SymT.FLOAT, symbol_id, positive=positive, real=True + SymT.FLOAT, len(self.var_to_val), positive=positive, real=True ) self.source_to_var[source_name] = sympy_expr # We always associate vars to vals @@ -6565,13 +6558,6 @@ class ShapeEnv: sloc, _ = self._get_stack_summary(framework_loc=framework_loc) return sloc - def _generate_unique_id(self, source_name: str) -> int: - attempt = int(hashlib.sha256(source_name.encode()).hexdigest(), 16) % 100 - while attempt in self.unique_ids: - attempt += 1 - self.unique_ids.add(attempt) - return attempt - def _find_frame_locals(self) -> _FrameLocalResult: """ Given the current user code frame, finds the relevant lines of code,