recovering node source from dict (#158373) (#158473)

Summary:

this diff recovers NodeSource object from its dict representation, which is crucial for NodeSource serde.

Test Plan:
ci

Rollback Plan:

Differential Revision: D78434648

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158473
Approved by: https://github.com/angelayi
This commit is contained in:
Songhao Jia 2025-07-17 17:00:14 +00:00 committed by PyTorch MergeBot
parent bfe5674e22
commit 6d31d38965
2 changed files with 76 additions and 6 deletions

View File

@ -32,6 +32,8 @@ class TestFXNodeSource(TestCase):
dummy_source_dict,
)
self.assertEqual(node_source, NodeSource._from_dict(node_source.to_dict()))
# Dummy node
node = torch.fx.Node(
graph=torch.fx.Graph(),
@ -179,14 +181,28 @@ class TestFXNodeSource(TestCase):
if node_name_1 in same_ancestor_nodes
else None,
}:
self.assertTrue(
node_name_to_from_node[node_name_1]
== node_name_to_from_node[node_name_2]
self.assertEqual(
node_name_to_from_node[node_name_1],
node_name_to_from_node[node_name_2],
)
self.assertEqual(
[
NodeSource._from_dict(ns.to_dict())
for ns in node_name_to_from_node[node_name_1]
],
node_name_to_from_node[node_name_2],
)
else:
self.assertTrue(
node_name_to_from_node[node_name_1]
!= node_name_to_from_node[node_name_2]
self.assertNotEqual(
node_name_to_from_node[node_name_1],
node_name_to_from_node[node_name_2],
)
self.assertNotEqual(
[
NodeSource._from_dict(ns.to_dict())
for ns in node_name_to_from_node[node_name_1]
],
node_name_to_from_node[node_name_2],
)
gm = ep.module()

View File

@ -153,6 +153,60 @@ class NodeSource:
return hash(_make_hashable(self.to_dict()))
@classmethod
def _from_dict(cls, d: Optional[dict]) -> Optional["NodeSource"]:
"""
Recursively deserialize from_node metadata from dictionary data.
It is used to deserialize the from_node field from serialized metadata.
Please use constructor NodeSource(node, ...) to create a NodeSource object.
"""
if d is None:
return None
assert isinstance(d, dict), f"Expected a dict, got {type(d)}"
# Create a NodeSource object directly without going through the constructor
# to avoid issues with graph ID and node creation
node_source = NodeSource.__new__(NodeSource)
# Reset the cached properties
node_source._action_string = None
node_source._dict = None
# Set the basic attributes
node_source.pass_name = d.get("pass_name", "")
# Parse action string back to NodeSourceAction enum list
action_str = d.get("action", "")
actions = []
if action_str:
for action_name in action_str.split("+"):
if action_name.upper() == "CREATE":
actions.append(NodeSourceAction.CREATE)
elif action_name.upper() == "REPLACE":
actions.append(NodeSourceAction.REPLACE)
node_source.action = actions
# Create the NodeInfo object directly
if "name" in d and "target" in d and "graph_id" in d:
node_info = NodeSource.NodeInfo(
d.get("name", ""), d.get("target", ""), d.get("graph_id", -1)
)
node_source.node_info = node_info
else:
node_source.node_info = None
# Recursively deserialize nested from_node
if d.get("from_node", None) is not None:
node_source.from_node = [
result
for fn in d.get("from_node", [])
if (result := cls._from_dict(fn)) is not None
]
else:
node_source.from_node = []
return node_source
@compatibility(is_backward_compatible=False)
@contextmanager