mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[inductor] Fix constant folder (#166655)
Fixes https://fb.workplace.com/groups/1028545332188949/permalink/1351999569843522/ where the resulting graph of constant folder uses a sym node which has been created later. Graph diff: https://www.internalfb.com/intern/diffing/?paste_number=2014609054 Before: ``` %full_65 : [num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([%sym_size_int_47, 768], 1), kwargs = {dtype: torch.int64, layout: torch.strided, device: cuda:0, pin_memory: False}) %select_18 : [num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%full_65, 1, 0), kwargs = {}) %mul_2792 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%select_18, 0), kwargs = {}) %embedding_4 : [num_users=1] = call_function[target=torch.ops.aten.embedding.default](args = (%_uv__surface_embeddings_weight, %mul_2792), kwargs = {}) ``` After: ``` %full_65 : [num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([%sym_size_int_47, 768], 1), kwargs = {dtype: torch.int64, layout: torch.strided, device: cuda:0, pin_memory: False}) %full_default_1 : [num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([%sym_size_int_150], 0), kwargs = {dtype: torch.int64, layout: torch.strided, device: cuda:0, pin_memory: False}) %embedding_4 : [num_users=1] = call_function[target=torch.ops.aten.embedding.default](args = (%_uv__surface_embeddings_weight, %full_default_1), kwargs = {}) ... %sym_size_int_150 : [num_users=7] = call_function[target=torch.ops.aten.sym_size.int](args = (%view_193, 0), kwargs = {}) ``` I couldn't figure out a small repro for this :/ Pull Request resolved: https://github.com/pytorch/pytorch/pull/166655 Approved by: https://github.com/eellison
This commit is contained in:
parent
b9bcb37f40
commit
984e64b2cd
|
|
@ -227,7 +227,8 @@ class UniformValueConstantFolder(ConstantFolder):
|
|||
self.symint_nodes = _SymHashingDict()
|
||||
for n in self.module.graph.nodes: # type: ignore[union-attr]
|
||||
if "val" in n.meta and isinstance(n.meta["val"], torch.SymInt):
|
||||
self.symint_nodes[n.meta["val"]] = n
|
||||
if n.meta["val"] not in self.symint_nodes:
|
||||
self.symint_nodes[n.meta["val"]] = n
|
||||
|
||||
# reference from torch/_funtorch/partitioners.py:get_default_op_list
|
||||
self.view_op_packets = [
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user