From 984e64b2cddecfe43501a7e01a1cce0c25473f54 Mon Sep 17 00:00:00 2001
From: angelayi
Date: Thu, 30 Oct 2025 22:51:25 +0000
Subject: [PATCH] [inductor] Fix constant folder (#166655)

Fixes https://fb.workplace.com/groups/1028545332188949/permalink/1351999569843522/
where the graph produced by the constant folder uses a sym node that is only
created later in the graph (in the "After" graph below, %full_default_1 uses
%sym_size_int_150 before it is defined).

Graph diff: https://www.internalfb.com/intern/diffing/?paste_number=2014609054

Before:
```
%full_65 : [num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([%sym_size_int_47, 768], 1), kwargs = {dtype: torch.int64, layout: torch.strided, device: cuda:0, pin_memory: False})
%select_18 : [num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%full_65, 1, 0), kwargs = {})
%mul_2792 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%select_18, 0), kwargs = {})
%embedding_4 : [num_users=1] = call_function[target=torch.ops.aten.embedding.default](args = (%_uv__surface_embeddings_weight, %mul_2792), kwargs = {})
```

After:
```
%full_65 : [num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([%sym_size_int_47, 768], 1), kwargs = {dtype: torch.int64, layout: torch.strided, device: cuda:0, pin_memory: False})
%full_default_1 : [num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([%sym_size_int_150], 0), kwargs = {dtype: torch.int64, layout: torch.strided, device: cuda:0, pin_memory: False})
%embedding_4 : [num_users=1] = call_function[target=torch.ops.aten.embedding.default](args = (%_uv__surface_embeddings_weight, %full_default_1), kwargs = {})
...
%sym_size_int_150 : [num_users=7] = call_function[target=torch.ops.aten.sym_size.int](args = (%view_193, 0), kwargs = {})
```

I couldn't figure out a small repro for this :/

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166655
Approved by: https://github.com/eellison
---
 torch/_inductor/fx_passes/joint_graph.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/torch/_inductor/fx_passes/joint_graph.py b/torch/_inductor/fx_passes/joint_graph.py
index 8f4568bd89f..42d5479a34f 100644
--- a/torch/_inductor/fx_passes/joint_graph.py
+++ b/torch/_inductor/fx_passes/joint_graph.py
@@ -227,7 +227,8 @@ class UniformValueConstantFolder(ConstantFolder):
         self.symint_nodes = _SymHashingDict()
         for n in self.module.graph.nodes:  # type: ignore[union-attr]
             if "val" in n.meta and isinstance(n.meta["val"], torch.SymInt):
-                self.symint_nodes[n.meta["val"]] = n
+                if n.meta["val"] not in self.symint_nodes:
+                    self.symint_nodes[n.meta["val"]] = n
 
         # reference from torch/_funtorch/partitioners.py:get_default_op_list
         self.view_op_packets = [
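
Note (editorial, not part of the patch): the change keeps the *first* graph node
seen for each SymInt value instead of overwriting with later ones, so any
replacement node the folder emits (e.g. the zero-filled `full` above) refers to a
sym node already defined upstream of the use site. The sketch below is a minimal,
hypothetical illustration of that dedup idea; `FakeNode`, `collect_symint_nodes`,
and the plain-string keys are stand-ins, since the real pass keys a
`_SymHashingDict` by the `SymInt` itself inside `UniformValueConstantFolder`.
```
class FakeNode:
    """Hypothetical stand-in for torch.fx.Node, just enough for this sketch."""

    def __init__(self, name, sym_val):
        self.name = name
        self.meta = {"val": sym_val}


def collect_symint_nodes(nodes):
    """Map each symbolic size value to the *earliest* node that produces it.

    Nodes arrive in graph (topological) order, so preferring the first match
    means any later rewrite that reuses the node as a size argument points at
    something defined before the use site.
    """
    symint_nodes = {}
    for n in nodes:
        val = n.meta.get("val")
        if val is not None and val not in symint_nodes:
            symint_nodes[val] = n
    return symint_nodes


# Two nodes carry the same symbolic size "s0"; the earlier one wins,
# mirroring the behavior after this patch.
nodes = [FakeNode("sym_size_int_47", "s0"), FakeNode("sym_size_int_150", "s0")]
assert collect_symint_nodes(nodes)["s0"].name == "sym_size_int_47"
```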