[PT] Update module partitioner to return parameter node (#101121)

Instead of returning param name, return parameter get_attr node.

Differential Revision: [D45713916](https://our.internmc.facebook.com/intern/diff/D45713916/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101121
Approved by: https://github.com/angelayi
This commit is contained in:
Kimish Patel 2023-05-15 20:16:40 -07:00 committed by PyTorch MergeBot
parent 75375b410d
commit bec655f826

View File

@ -44,7 +44,7 @@ class SourcePartition():
output_nodes: List[Node] = field(default_factory=list)
# Parameters that are being used
params: List[str] = field(default_factory=list)
params: List[Node] = field(default_factory=list)
@compatibility(is_backward_compatible=False)
@ -92,7 +92,7 @@ def get_source_partitions(
input_nodes.add(arg)
if node.op == "get_attr":
params.add(node.target)
params.add(node)
for user in node.users.keys():
if user not in nodes: