preserve custom meta in placeholders (#149661)

Fixes #147338

Differential Revision: [D71573533](https://our.internmc.facebook.com/intern/diff/D71573533/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/149661
Approved by: https://github.com/junpeiz, https://github.com/angelayi
This commit is contained in:
Avik Chaudhuri 2025-03-20 15:44:06 -07:00 committed by PyTorch MergeBot
parent 0eb3ac9349
commit 09aa63ea2c
2 changed files with 38 additions and 0 deletions

View File

@ -11848,6 +11848,36 @@ def forward(self, x):
return (getitem_3, cos_1)""",
)
def test_run_decompositions_keep_metadata(self):
"""Make sure the metadata is kept after exported program run_decompositions."""
@torch.library.custom_op("mylib::add", mutates_args=())
def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
...
@torch.library.register_fake("mylib::add")
def _(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
return torch.empty_like(x)
class TestModel(torch.nn.Module):
def forward(self, x, y):
return torch.ops.mylib.add(x, y)
model = TestModel()
x_example = torch.randn(2, 3)
y_example = torch.randn(2, 3)
exported_program = export(model, (x_example, y_example))
for node in exported_program.graph.nodes:
node.meta["custom"] = {"my_field": "dummy"}
for node in exported_program.graph.nodes:
self.assertEqual(node.meta["custom"]["my_field"], "dummy")
decomposed_program = exported_program.run_decompositions()
for node in decomposed_program.graph.nodes:
self.assertEqual(node.meta["custom"]["my_field"], "dummy")
def test_export_linear_preserve_dynamic_shape(self):
class M(torch.nn.Module):
def __init__(self):

View File

@ -844,6 +844,12 @@ def placeholder_naming_pass(
These are named token, token_1, ...
"""
custom_meta: dict[str, Any] = {}
if isinstance(mod, torch.fx.GraphModule):
for node in mod.graph.nodes:
if "custom" in node.meta:
custom_meta[node.name] = node.meta["custom"]
def _strip_name(x):
if x.startswith("L__self___"):
x = x[len("L__self___") :]
@ -918,6 +924,8 @@ def placeholder_naming_pass(
if node.op == "placeholder":
assert node.name in name_map
node.name = node.target = name_map[node.name]
if node.name in custom_meta:
node.meta["custom"] = custom_meta[node.name]
# if the constant obj is an input, we also need to update meta["val"]
# because this is created before the placeholder naming pass
if isinstance(node.meta["val"], CustomObjArgument):