mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
[PT] Update module partitioner to return parameter node (#101121)
Instead of returning param name, return parameter get_attr node. Differential Revision: [D45713916](https://our.internmc.facebook.com/intern/diff/D45713916/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/101121 Approved by: https://github.com/angelayi
This commit is contained in:
parent
75375b410d
commit
bec655f826
|
|
@ -44,7 +44,7 @@ class SourcePartition():
|
|||
output_nodes: List[Node] = field(default_factory=list)
|
||||
|
||||
# Parameters that are being used
|
||||
params: List[str] = field(default_factory=list)
|
||||
params: List[Node] = field(default_factory=list)
|
||||
|
||||
|
||||
@compatibility(is_backward_compatible=False)
|
||||
|
|
@ -92,7 +92,7 @@ def get_source_partitions(
|
|||
input_nodes.add(arg)
|
||||
|
||||
if node.op == "get_attr":
|
||||
params.add(node.target)
|
||||
params.add(node)
|
||||
|
||||
for user in node.users.keys():
|
||||
if user not in nodes:
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user