Fixes #113191

Docstring lint counts from ``pydocstyle <file> --count``:

| File | On master | After this PR |
| --- | --- | --- |
| torch/distributed/fsdp/fully_sharded_data_parallel.py | 80 | 3 |
| torch/distributed/_spmd/comm_tensor.py | 5 | 3 |
| torch/distributed/_spmd/experimental_ops.py | 3 | 1 |
| torch/distributed/_spmd/iter_graph_module.py | 39 | 27 |
| torch/distributed/_spmd/graph_utils.py | 16 | 4 |
| torch/distributed/_spmd/distribute.py | 19 | 10 |
| torch/distributed/_spmd/api.py | 10 | 3 |
| torch/distributed/_spmd/batch_dim_utils.py | 14 | 3 |
| torch/distributed/_spmd/data_parallel.py | 34 | 2 |
| torch/distributed/_spmd/graph_optimization.py | 35 | 13 |

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113216
Approved by: https://github.com/ezyang
146 lines
4.8 KiB
Python
import logging
import os
import tempfile
from enum import Enum
from typing import Callable, cast, Dict, Iterable, List, Set

import torch.fx as fx
from torch.fx.passes.shape_prop import TensorMetadata
from torch.utils import _pytree as pytree
from torch.utils._pytree import tree_flatten, tree_unflatten


logger: logging.Logger = logging.getLogger("graph_utils")


class OP(str, Enum):
    CALL_FUNCTION = "call_function"
    CALL_MODULE = "call_module"
    CALL_METHOD = "call_method"
    GET_ATTR = "get_attr"
    OUTPUT = "output"
    PLACEHOLDER = "placeholder"


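# A minimal usage sketch, assuming ``torch`` is imported: because ``OP``
# subclasses ``str``, its members compare equal to the raw strings that fx
# stores in ``node.op``, so no conversion is needed:
#
#     gm = fx.symbolic_trace(torch.nn.ReLU())
#     first_node = next(iter(gm.graph.nodes))
#     assert first_node.op == OP.PLACEHOLDER  # compares equal to "placeholder"

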
class CommType(str, Enum):
    ALLREDUCE = "allreduce_"
    ALLGATHER = "allgather_"
    BROADCAST = "broadcast_"
    REDUCESCATTER = "reduce_scatter_"
    SCATTER = "scatter_"


def get_node_tensor_metadata(node: fx.Node, is_required: bool = True) -> TensorMetadata:
    """Return the ``tensor_meta`` stored on ``node``; raise if it is required but missing."""
    metadata = node.meta.get("tensor_meta", None)
    if is_required and metadata is None:
        raise RuntimeError(
            f"Callsite expects that ``tensor_meta`` exists in ``{node.name}``, "
            f"but got None instead. Node: {node.op} {node.name} {node.target}"
        )
    return metadata


def get_output(graph: fx.Graph) -> fx.Node:
    """Take a graph and return its output node.

    Traverse in reverse order to expedite the search, since the output node
    is normally the last node in the graph.
    """
    for node in reversed(graph.nodes):
        if node.op == OP.OUTPUT:
            return node
    raise RuntimeError(f"Cannot find the output node in {graph}")


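# A minimal usage sketch, assuming a traced module ``gm``:
#
#     output_node = get_output(gm.graph)
#     assert output_node.op == OP.OUTPUT

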
def find_node(
    graph: fx.Graph, predicate: Callable, reverse_order: bool = False
) -> List[fx.Node]:
    """Take a predicate and return all the nodes in the `graph` where the predicate holds."""
    nodes = cast(Iterable[fx.Node], graph.nodes)
    if reverse_order:
        nodes = cast(Iterable[fx.Node], iter(reversed(nodes)))  # type: ignore[call-overload]
    return [node for node in nodes if predicate(node)]


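# A minimal usage sketch, assuming a traced module ``gm``: collect every
# ``call_function`` node, scanning from the end of the graph:
#
#     calls = find_node(
#         gm.graph, lambda n: n.op == OP.CALL_FUNCTION, reverse_order=True
#     )

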
def is_leaf_subgraph(graph: fx.Graph, subgraph: List[fx.Node]) -> bool:
    """Check that every node in ``subgraph`` satisfies one of the following rules.

    1. The user of the node is in ``subgraph``.
    2. The user of the node is the output node.
    3. There are no users -- the node is a side-effect node.
    """
    all_nodes: Set[fx.Node] = set(subgraph)
    output = get_output(graph)
    for node in subgraph:
        for user in node.users:
            if not isinstance(user, fx.Node):
                continue
            if user not in all_nodes and user != output:
                return False
    return True


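# A minimal sketch of the rules above, assuming a graph with nodes
# ``a -> b -> output``: ``[a, b]`` is a leaf subgraph because b's only user
# is the output node, while ``[a]`` alone is not, because a's user ``b``
# lies outside the subgraph:
#
#     assert is_leaf_subgraph(graph, [a, b])
#     assert not is_leaf_subgraph(graph, [a])

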
def clone_subgraph(
    graph: fx.Graph, subgraph: List[fx.Node], target: fx.Node
) -> List[fx.Node]:
    """Clone the given subgraph and insert it before ``target``.

    This API currently does not support inserting after ``target``.
    """
    all_nodes = set(subgraph)
    mapping: Dict[fx.Node, fx.Node] = dict()
    cloned_subgraph = []
    with graph.inserting_before(target):
        for node in subgraph:
            cloned_node = graph.call_function(
                node.target, node.args, node.kwargs, node.type
            )
            # TODO: there are many flatten/unflatten in IterGraph that
            # can be simplified with tree_map. Will simplify this in
            # a follow-up PR.
            original_input = pytree.arg_tree_leaves(*node.args, **node.kwargs)
            cloned_input, spec = tree_flatten((cloned_node.args, cloned_node.kwargs))
            mapped_cloned_input = []
            for original_input_node, cloned_input_node in zip(
                original_input, cloned_input
            ):
                if (
                    isinstance(original_input_node, fx.Node)
                    and original_input_node in all_nodes
                ):
                    assert original_input_node in mapping
                    mapped_cloned_input.append(mapping[original_input_node])
                else:
                    mapped_cloned_input.append(cloned_input_node)
            cloned_node.args, cloned_node.kwargs = tree_unflatten(
                mapped_cloned_input, spec
            )
            mapping[node] = cloned_node
            cloned_subgraph.append(cloned_node)

    return cloned_subgraph


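# A minimal usage sketch, assuming ``subgraph`` is a leaf subgraph of
# ``gm.graph`` and ``anchor`` is the node the clone should precede:
#
#     if is_leaf_subgraph(gm.graph, subgraph):
#         cloned = clone_subgraph(gm.graph, subgraph, target=anchor)
#         rebuild_graph(gm)  # lint, eliminate dead code, and recompile

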
def rebuild_graph(gm: fx.GraphModule, remove_dead_code: bool = True) -> None:
    """Run the required steps to ensure a production-ready graph.

    Note: per the fx docs, dead code elimination is not very precise;
    hence the flag to make this step optional.
    """
    gm.graph.lint()
    if remove_dead_code:
        gm.graph.eliminate_dead_code()
    gm.recompile()


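# A minimal usage sketch: after any graph surgery (e.g., ``clone_subgraph``),
# rebuild before running the module; pass ``remove_dead_code=False`` to skip
# the imprecise dead-code elimination:
#
#     rebuild_graph(gm, remove_dead_code=False)

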
def dump_graphs_to_files(graphs: Dict[str, fx.GraphModule], folder: str = "") -> str:
    """Write each graph module to ``<folder>/<prefix>.graph`` and return the folder path."""
    if not folder:
        folder = tempfile.mkdtemp()

    for prefix, gm in graphs.items():
        with open(os.path.join(folder, f"{prefix}.graph"), "w") as fp:
            fp.write(str(gm))

    logger.warning("Dump graphs to %s", folder)

    return folder
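# A minimal usage sketch, assuming traced modules ``fwd_gm`` and ``bwd_gm``:
# each module is written to ``<folder>/<prefix>.graph`` for offline inspection:
#
#     folder = dump_graphs_to_files({"forward": fwd_gm, "backward": bwd_gm})
#     logger.info("Graphs dumped to %s", folder)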