[fbia] Keep Track of full qualified name before and after remote sharding (#83889)

Summary: track qualname changes in embedding sharding & FX split, and compose target qualname in the end of FBIA transform stage, so we can use the qualname mapping in XL materialize stage

Test Plan:
CI/CD

with DISABLE_XLEBB_MATERIALIZATION = True
https://fburl.com/fblearner/a8yljbux

with DISABLE_XLEBB_MATERIALIZATION = False
https://fburl.com/fblearner/2nvi0dam

Reviewed By: lliu315gt

Differential Revision: D38772525

Pull Request resolved: https://github.com/pytorch/pytorch/pull/83889
Approved by: https://github.com/houseroad
This commit is contained in:
Nan Xiao 2022-08-24 01:15:25 +00:00 committed by PyTorch MergeBot
parent 58f61d50a4
commit c47e0450f8

View File

@ -26,6 +26,7 @@ from .tools_common import (
)
import warnings
__all__ = ['FxNetAccNodesFinder', 'FxNetSplitterInternalError', 'Subgraph', 'SplitResult', 'generate_inputs_for_submodules']
_LOGGER = logging.getLogger(__name__)
@ -315,11 +316,21 @@ class _SplitterBase:
self.update_deps_for_fusions()
self.non_acc_submodule_name = non_acc_submodule_name
self._node_submodule_map: Dict[str, str] = {}
# ===============================================================
# Helpers for ctor and initial state
# ===============================================================
def get_node_submodule_map(self) -> Dict[str, str]:
""" Returns a map from node name to submodule name, e.g.
node: main_module_impl_impl_over_arch_unary_multiple_embedding
_pooling_embedding_pooling_sparse_entity_equivalence_key
_proxy_embedding_bag
maps to submodule name of: _run_on_acc_1
"""
return self._node_submodule_map
def find_deps(self) -> Dict[torch.fx.Node, NodeSet]:
"""
Builds a graph of node dependencies. Leaf nodes don't have any
@ -715,12 +726,9 @@ class _SplitterBase:
current_cpu_nodes, current_acc_nodes = self.starter_nodes()
visited_nodes: NodeSet = set()
# Determine which subgraph to start from based on node dependency
acc_subgraph: bool = True
for n in current_cpu_nodes:
if self.deps[n] <= visited_nodes:
acc_subgraph = False
break
# Determine which subgraph to start from based on which subgraph has
# 0-dep node
acc_subgraph: bool = not any([len(self.deps[n]) == 0 for n in current_cpu_nodes])
current_subgraph_nodes: NodeList = []
@ -809,14 +817,14 @@ class _SplitterBase:
def tag(self, subgraphs: List[Subgraph]):
self.tags: List[str] = []
for subgraph in subgraphs:
subgraph_name = self.non_acc_submodule_name
tag = f"_run_on_acc_{len(self.tags)}" if subgraph.is_acc else f"{self.non_acc_submodule_name}{len(self.tags)}"
self.tags.append(tag)
for node in subgraph.nodes:
if hasattr(node, "tag"):
raise FxNetSplitterInternalError(f"Node {node} was already tagged")
node.tag = tag # type: ignore[attr-defined]
self._node_submodule_map[node.name] = tag
def split(self, remove_tag: bool = False) -> torch.fx.GraphModule:
split_module = split_by_tags(self.module, self.tags)