mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
preserve custom meta in placeholders (#149661)
Fixes #147338 Differential Revision: [D71573533](https://our.internmc.facebook.com/intern/diff/D71573533/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/149661 Approved by: https://github.com/junpeiz, https://github.com/angelayi
This commit is contained in:
parent
0eb3ac9349
commit
09aa63ea2c
|
|
@ -11848,6 +11848,36 @@ def forward(self, x):
|
|||
return (getitem_3, cos_1)""",
|
||||
)
|
||||
|
||||
def test_run_decompositions_keep_metadata(self):
|
||||
"""Make sure the metadata is kept after exported program run_decompositions."""
|
||||
|
||||
@torch.library.custom_op("mylib::add", mutates_args=())
|
||||
def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
||||
...
|
||||
|
||||
@torch.library.register_fake("mylib::add")
|
||||
def _(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
||||
return torch.empty_like(x)
|
||||
|
||||
class TestModel(torch.nn.Module):
|
||||
def forward(self, x, y):
|
||||
return torch.ops.mylib.add(x, y)
|
||||
|
||||
model = TestModel()
|
||||
x_example = torch.randn(2, 3)
|
||||
y_example = torch.randn(2, 3)
|
||||
exported_program = export(model, (x_example, y_example))
|
||||
|
||||
for node in exported_program.graph.nodes:
|
||||
node.meta["custom"] = {"my_field": "dummy"}
|
||||
|
||||
for node in exported_program.graph.nodes:
|
||||
self.assertEqual(node.meta["custom"]["my_field"], "dummy")
|
||||
|
||||
decomposed_program = exported_program.run_decompositions()
|
||||
for node in decomposed_program.graph.nodes:
|
||||
self.assertEqual(node.meta["custom"]["my_field"], "dummy")
|
||||
|
||||
def test_export_linear_preserve_dynamic_shape(self):
|
||||
class M(torch.nn.Module):
|
||||
def __init__(self):
|
||||
|
|
|
|||
|
|
@ -844,6 +844,12 @@ def placeholder_naming_pass(
|
|||
These are named token, token_1, ...
|
||||
"""
|
||||
|
||||
custom_meta: dict[str, Any] = {}
|
||||
if isinstance(mod, torch.fx.GraphModule):
|
||||
for node in mod.graph.nodes:
|
||||
if "custom" in node.meta:
|
||||
custom_meta[node.name] = node.meta["custom"]
|
||||
|
||||
def _strip_name(x):
|
||||
if x.startswith("L__self___"):
|
||||
x = x[len("L__self___") :]
|
||||
|
|
@ -918,6 +924,8 @@ def placeholder_naming_pass(
|
|||
if node.op == "placeholder":
|
||||
assert node.name in name_map
|
||||
node.name = node.target = name_map[node.name]
|
||||
if node.name in custom_meta:
|
||||
node.meta["custom"] = custom_meta[node.name]
|
||||
# if the constant obj is an input, we also need to update meta["val"]
|
||||
# because this is created before the placeholder naming pass
|
||||
if isinstance(node.meta["val"], CustomObjArgument):
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user