mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
[fbia] Keep Track of full qualified name before and after remote sharding (#83889)
Summary: Track qualname changes in embedding sharding & FX split, and compose the target qualname at the end of the FBIA transform stage, so that the qualname mapping can be used in the XL materialize stage. Test Plan: CI/CD with DISABLE_XLEBB_MATERIALIZATION = True https://fburl.com/fblearner/a8yljbux; with DISABLE_XLEBB_MATERIALIZATION = False https://fburl.com/fblearner/2nvi0dam. Reviewed By: lliu315gt Differential Revision: D38772525 Pull Request resolved: https://github.com/pytorch/pytorch/pull/83889 Approved by: https://github.com/houseroad
This commit is contained in:
parent
58f61d50a4
commit
c47e0450f8
|
|
@ -26,6 +26,7 @@ from .tools_common import (
|
|||
)
|
||||
import warnings
|
||||
|
||||
|
||||
__all__ = ['FxNetAccNodesFinder', 'FxNetSplitterInternalError', 'Subgraph', 'SplitResult', 'generate_inputs_for_submodules']
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -315,11 +316,21 @@ class _SplitterBase:
|
|||
self.update_deps_for_fusions()
|
||||
|
||||
self.non_acc_submodule_name = non_acc_submodule_name
|
||||
self._node_submodule_map: Dict[str, str] = {}
|
||||
|
||||
# ===============================================================
|
||||
# Helpers for ctor and initial state
|
||||
# ===============================================================
|
||||
|
||||
def get_node_submodule_map(self) -> Dict[str, str]:
    """Expose the node-name -> submodule-name mapping built during tagging.

    Each key is the qualified name of a node in the original graph, e.g.::

        main_module_impl_impl_over_arch_unary_multiple_embedding
        _pooling_embedding_pooling_sparse_entity_equivalence_key
        _proxy_embedding_bag

    and the corresponding value is the name of the split submodule the
    node was assigned to, e.g. ``_run_on_acc_1``.
    """
    # The map is populated as a side effect of tag(); this is a plain
    # read-only accessor.
    return self._node_submodule_map
|
||||
|
||||
def find_deps(self) -> Dict[torch.fx.Node, NodeSet]:
|
||||
"""
|
||||
Builds a graph of node dependencies. Leaf nodes don't have any
|
||||
|
|
@ -715,12 +726,9 @@ class _SplitterBase:
|
|||
current_cpu_nodes, current_acc_nodes = self.starter_nodes()
|
||||
visited_nodes: NodeSet = set()
|
||||
|
||||
# Determine which subgraph to start from based on node dependency
|
||||
acc_subgraph: bool = True
|
||||
for n in current_cpu_nodes:
|
||||
if self.deps[n] <= visited_nodes:
|
||||
acc_subgraph = False
|
||||
break
|
||||
# Determine which subgraph to start from based on which subgraph has
|
||||
# 0-dep node
|
||||
acc_subgraph: bool = not any([len(self.deps[n]) == 0 for n in current_cpu_nodes])
|
||||
|
||||
current_subgraph_nodes: NodeList = []
|
||||
|
||||
|
|
@ -809,14 +817,14 @@ class _SplitterBase:
|
|||
def tag(self, subgraphs: List[Subgraph]):
    """Assign a unique tag to every node of every subgraph.

    Acc subgraphs get tags of the form ``_run_on_acc_<i>``; non-acc
    subgraphs get ``<self.non_acc_submodule_name><i>``, where ``<i>`` is
    the subgraph's position in the overall tag order. The tag is stored
    on each node (``node.tag``) and recorded in
    ``self._node_submodule_map`` keyed by node name, so the
    qualname -> submodule mapping can be queried after the split.

    Args:
        subgraphs: ordered list of subgraphs produced by the splitter.

    Raises:
        FxNetSplitterInternalError: if any node already carries a ``tag``
            attribute, which would indicate it was processed twice.
    """
    self.tags: List[str] = []
    for subgraph in subgraphs:
        # Removed dead local `subgraph_name`: it was assigned from
        # self.non_acc_submodule_name but never read.
        tag = f"_run_on_acc_{len(self.tags)}" if subgraph.is_acc else f"{self.non_acc_submodule_name}{len(self.tags)}"
        self.tags.append(tag)
        for node in subgraph.nodes:
            if hasattr(node, "tag"):
                raise FxNetSplitterInternalError(f"Node {node} was already tagged")

            node.tag = tag  # type: ignore[attr-defined]
            self._node_submodule_map[node.name] = tag
|
||||
|
||||
def split(self, remove_tag: bool = False) -> torch.fx.GraphModule:
|
||||
split_module = split_by_tags(self.module, self.tags)
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user