mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Summary: this diff recovers NodeSource object from its dict representation, which is crucial for NodeSource serde. Test Plan: ci Rollback Plan: Differential Revision: D78434648 Pull Request resolved: https://github.com/pytorch/pytorch/pull/158473 Approved by: https://github.com/angelayi
This commit is contained in:
parent
bfe5674e22
commit
6d31d38965
|
|
@ -32,6 +32,8 @@ class TestFXNodeSource(TestCase):
|
|||
dummy_source_dict,
|
||||
)
|
||||
|
||||
self.assertEqual(node_source, NodeSource._from_dict(node_source.to_dict()))
|
||||
|
||||
# Dummy node
|
||||
node = torch.fx.Node(
|
||||
graph=torch.fx.Graph(),
|
||||
|
|
@ -179,14 +181,28 @@ class TestFXNodeSource(TestCase):
|
|||
if node_name_1 in same_ancestor_nodes
|
||||
else None,
|
||||
}:
|
||||
self.assertTrue(
|
||||
node_name_to_from_node[node_name_1]
|
||||
== node_name_to_from_node[node_name_2]
|
||||
self.assertEqual(
|
||||
node_name_to_from_node[node_name_1],
|
||||
node_name_to_from_node[node_name_2],
|
||||
)
|
||||
self.assertEqual(
|
||||
[
|
||||
NodeSource._from_dict(ns.to_dict())
|
||||
for ns in node_name_to_from_node[node_name_1]
|
||||
],
|
||||
node_name_to_from_node[node_name_2],
|
||||
)
|
||||
else:
|
||||
self.assertTrue(
|
||||
node_name_to_from_node[node_name_1]
|
||||
!= node_name_to_from_node[node_name_2]
|
||||
self.assertNotEqual(
|
||||
node_name_to_from_node[node_name_1],
|
||||
node_name_to_from_node[node_name_2],
|
||||
)
|
||||
self.assertNotEqual(
|
||||
[
|
||||
NodeSource._from_dict(ns.to_dict())
|
||||
for ns in node_name_to_from_node[node_name_1]
|
||||
],
|
||||
node_name_to_from_node[node_name_2],
|
||||
)
|
||||
|
||||
gm = ep.module()
|
||||
|
|
|
|||
|
|
@ -153,6 +153,60 @@ class NodeSource:
|
|||
|
||||
return hash(_make_hashable(self.to_dict()))
|
||||
|
||||
@classmethod
|
||||
def _from_dict(cls, d: Optional[dict]) -> Optional["NodeSource"]:
|
||||
"""
|
||||
Recursively deserialize from_node metadata from dictionary data.
|
||||
It is used to deserialize the from_node field from serialized metadata.
|
||||
Please use constructor NodeSource(node, ...) to create a NodeSource object.
|
||||
"""
|
||||
if d is None:
|
||||
return None
|
||||
|
||||
assert isinstance(d, dict), f"Expected a dict, got {type(d)}"
|
||||
|
||||
# Create a NodeSource object directly without going through the constructor
|
||||
# to avoid issues with graph ID and node creation
|
||||
node_source = NodeSource.__new__(NodeSource)
|
||||
|
||||
# Reset the cached properties
|
||||
node_source._action_string = None
|
||||
node_source._dict = None
|
||||
|
||||
# Set the basic attributes
|
||||
node_source.pass_name = d.get("pass_name", "")
|
||||
|
||||
# Parse action string back to NodeSourceAction enum list
|
||||
action_str = d.get("action", "")
|
||||
actions = []
|
||||
if action_str:
|
||||
for action_name in action_str.split("+"):
|
||||
if action_name.upper() == "CREATE":
|
||||
actions.append(NodeSourceAction.CREATE)
|
||||
elif action_name.upper() == "REPLACE":
|
||||
actions.append(NodeSourceAction.REPLACE)
|
||||
node_source.action = actions
|
||||
|
||||
# Create the NodeInfo object directly
|
||||
if "name" in d and "target" in d and "graph_id" in d:
|
||||
node_info = NodeSource.NodeInfo(
|
||||
d.get("name", ""), d.get("target", ""), d.get("graph_id", -1)
|
||||
)
|
||||
node_source.node_info = node_info
|
||||
else:
|
||||
node_source.node_info = None
|
||||
|
||||
# Recursively deserialize nested from_node
|
||||
if d.get("from_node", None) is not None:
|
||||
node_source.from_node = [
|
||||
result
|
||||
for fn in d.get("from_node", [])
|
||||
if (result := cls._from_dict(fn)) is not None
|
||||
]
|
||||
else:
|
||||
node_source.from_node = []
|
||||
return node_source
|
||||
|
||||
|
||||
@compatibility(is_backward_compatible=False)
|
||||
@contextmanager
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user