typing distributed.py (#160365)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160365
Approved by: https://github.com/StrongerXi
ghstack dependencies: #160362, #160363, #160364
Authored by Lucas Kabela on 2025-08-13 15:40:10 -07:00, committed by PyTorch MergeBot
parent 9faca5f260
commit 453cfa5153
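
The diff below removes the file-level "# mypy: ignore-errors" escape hatch and adds type annotations throughout torch/_dynamo/backends/distributed.py, importing the CompiledFn and CompilerFn aliases from torch._dynamo.backends.registry for the backend-facing signatures. A minimal sketch of that pattern follows, assuming CompilerFn behaves as a callable taking (fx.GraphModule, list[torch.Tensor]) and returning a CompiledFn; the helper name compile_with_backend is illustrative and not part of this commit:

# Sketch of the annotation pattern applied in this commit. The alias semantics
# (CompilerFn: (GraphModule, example_inputs) -> CompiledFn) are assumed here,
# not quoted from torch._dynamo.backends.registry.
import torch
from torch import fx
from torch._dynamo.backends.registry import CompiledFn, CompilerFn


def compile_with_backend(
    gm: fx.GraphModule,
    example_inputs: list[torch.Tensor],
    backend_compile_fn: CompilerFn,
) -> CompiledFn:
    # Mirrors the annotated DDPOptimizer.compile_fn signature: a backend
    # compiler consumes a graph module plus example inputs and returns a
    # compiled callable.
    return backend_compile_fn(gm, example_inputs)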


@@ -1,5 +1,3 @@
# mypy: ignore-errors
"""
This module implements distributed training optimizations for TorchDynamo backends.
@@ -21,11 +19,12 @@ of compilation.
import logging
import traceback
from dataclasses import dataclass, field
from typing import Any, Optional
from typing import Any, Callable, Optional
from unittest import mock
import torch
from torch import fx
from torch._dynamo.backends.registry import CompiledFn, CompilerFn
from torch._dynamo.output_graph import GraphCompileReason
from torch._dynamo.utils import deepcopy_to_fake_tensor, detect_fake_mode
from torch._logging import trace_structured
@@ -39,7 +38,7 @@ log = logging.getLogger(__name__)
ddp_graph_log = torch._logging.getArtifactLogger(__name__, "ddp_graphs")
def args_str(args):
def args_str(args: Any) -> str:
# a debug helper
if torch.is_tensor(args):
return f"T[{args.shape}]"
@@ -58,7 +57,7 @@ class Bucket:
nodes: list[fx.Node] = field(default_factory=list)
# param_ids is just used for unit testing
param_ids: list = field(default_factory=list)
param_ids: list[int] = field(default_factory=list)
# keep track of any buckets that were extended for logging purposes
opcount_increased_to_capture_external_output: int = 0
@@ -78,9 +77,9 @@ def bucket_has_external_output(bucket: Bucket) -> bool:
return False
def pretty_print_buckets(buckets: list[Bucket], bucket_bytes_cap: int):
def pretty_print_buckets(buckets: list[Bucket], bucket_bytes_cap: int) -> None:
headers = ("Index", "Size (b)", "Param Names")
rows = []
rows: list[tuple[Optional[int], Optional[int], str]] = []
extended_buckets = []
for idx, bucket in enumerate(reversed(buckets)):
if len(bucket.params) > 0:
@@ -136,7 +135,7 @@ def pretty_print_buckets(buckets: list[Bucket], bucket_bytes_cap: int):
log.debug("DDPOptimizer captured no parameters and did not split this graph.")
def has_higher_order_op(gm):
def has_higher_order_op(gm: fx.GraphModule) -> bool:
# Check if there is a higher order op in the graph
for node in gm.graph.nodes:
if node.op == "get_attr":
@@ -146,7 +145,7 @@ def has_higher_order_op(gm):
return False
def propagate_metadata(orig_gm, split_gm) -> None:
def propagate_metadata(orig_gm: fx.GraphModule, split_gm: fx.GraphModule) -> None:
for name, module in split_gm.named_modules():
if "." not in name and len(name):
# TODO: add split id to CompileId: https://github.com/pytorch/tlparse/pull/83/files#r1880649384
@@ -154,7 +153,7 @@ def propagate_metadata(orig_gm, split_gm) -> None:
module._param_name_to_source = orig_gm._param_name_to_source
def propagate_dynamo_source(orig_gm, split_gm) -> None:
def propagate_dynamo_source(orig_gm: fx.GraphModule, split_gm: fx.GraphModule) -> None:
name_to_dynamo_source = {}
for node in orig_gm.graph.find_nodes(op="placeholder"):
name_to_dynamo_source[node.name] = node._dynamo_source
@@ -168,12 +167,19 @@ def propagate_dynamo_source(orig_gm, split_gm) -> None:
# compile each of the partitioned submodules using the user-provided compiler
class SubmodCompiler(torch.fx.interpreter.Interpreter):
def __init__(self, module, compiler, fake_mode) -> None:
def __init__(
self,
module: fx.GraphModule,
compiler: CompilerFn,
fake_mode: torch._subclasses.fake_tensor.FakeTensorMode,
) -> None:
super().__init__(module)
self.compiler = compiler
self.fake_mode = fake_mode
def compile_submod(self, input_mod, args, kwargs):
def compile_submod(
self, input_mod: fx.GraphModule, args: list[torch.Tensor], kwargs: Any
) -> Any:
"""
Compile the submodule,
using a wrapper to make sure its output is always a tuple,
@@ -182,12 +188,14 @@ class SubmodCompiler(torch.fx.interpreter.Interpreter):
assert len(kwargs) == 0, "We assume only args for these modules"
class WrapperModule(torch.nn.Module):
def __init__(self, submod, unwrap_singleton_tuple) -> None:
def __init__(
self, submod: Callable[..., Any], unwrap_singleton_tuple: bool
) -> None:
super().__init__()
self.submod = submod
self.unwrap_singleton_tuple = unwrap_singleton_tuple
def forward(self, *args):
def forward(self, *args: Any) -> Any:
x = self.submod(*args)
# TODO(whc)
# for some reason the isinstance check is necessary if I split one node per submod
@@ -205,12 +213,12 @@ class SubmodCompiler(torch.fx.interpreter.Interpreter):
sn.args = (sn.args,)
input_mod.recompile()
input_mod.compile_subgraph_reason = GraphCompileReason(
input_mod.compile_subgraph_reason = GraphCompileReason( # type: ignore[assignment]
"DDPOptimizer intentional graph-break (See Note [DDPOptimizer])."
" Set `torch._dynamo.config.optimize_ddp = False` to disable.",
[
# it's close to useless to get a real stacktrace here, and quite verbose.
traceback.FrameSummary(__file__, 0, DDPOptimizer),
traceback.FrameSummary(__file__, 0, "DDPOptimizer"),
],
)
@@ -257,7 +265,7 @@ class SubmodCompiler(torch.fx.interpreter.Interpreter):
assert isinstance(kwargs, dict)
if n.op == "call_module":
real_mod = self.fetch_attr(n.target)
real_mod = self.fetch_attr(str(n.target))
if self.fake_mode:
curr_submod = deepcopy_to_fake_tensor(real_mod, self.fake_mode)
else:
@@ -287,10 +295,10 @@ class SubmodCompiler(torch.fx.interpreter.Interpreter):
def __init__(self) -> None:
self.tc = torch._guards.TracingContext.try_get()
assert self.tc
torch._guards.TracingContext.try_get().fakify_first_call = True
self.tc.fakify_first_call = True
def __del__(self) -> None:
self.tc.fakify_first_call = False
self.tc.fakify_first_call = False # type: ignore[union-attr]
# For aot_eager and other backends, tracing context is not set
has_tracing_context = torch._guards.TracingContext.try_get() is not None
@@ -308,9 +316,9 @@ class SubmodCompiler(torch.fx.interpreter.Interpreter):
# We update the original (outer) graph with a call into the compiled module
# instead of the uncompiled one.
self.module.delete_submodule(n.target)
n.target = "compiled_" + n.target
self.module.add_submodule(n.target, compiled_submod_real)
self.module.delete_submodule(n.target) # type: ignore[operator]
n.target = "compiled_" + n.target # type: ignore[operator]
self.module.add_submodule(n.target, compiled_submod_real) # type: ignore[operator]
# Finally, we have to produce inputs for use compiling the next submodule,
# and these need to be FakeTensors, so we execute the module under fake_mode
@@ -398,7 +406,7 @@ class DDPOptimizer:
def __init__(
self,
bucket_bytes_cap: int,
backend_compile_fn,
backend_compile_fn: CompilerFn,
first_bucket_cap: Optional[int] = None,
) -> None:
if first_bucket_cap is not None:
@@ -416,21 +424,27 @@ class DDPOptimizer:
self.backend_compile_fn = backend_compile_fn
def _ignore_parameter(self, parameter):
def _ignore_parameter(self, parameter: torch.nn.Parameter) -> bool:
return hasattr(parameter, "_ddp_ignored") and parameter._ddp_ignored
def add_param(self, bucket, param, name):
def add_param(self, bucket: Bucket, param: torch.nn.Parameter, name: str) -> None:
bucket.size += param.untyped_storage().nbytes()
bucket.params.append(name)
bucket.param_ids.append(id(param))
def add_module_params_to_bucket(self, mod, bucket, processed_modules, prefix):
def add_module_params_to_bucket(
self,
mod: torch.nn.Module,
bucket: Bucket,
processed_modules: set[torch.nn.Module],
prefix: str,
) -> None:
processed_modules.add(mod)
for name, param in mod.named_parameters():
if param.requires_grad and not self._ignore_parameter(param):
self.add_param(bucket, param, f"{prefix}_{name}")
def add_param_args(self, bucket, node):
def add_param_args(self, bucket: Bucket, node: fx.Node) -> None:
for arg in node.args:
if not isinstance(arg, torch.fx.node.Node):
continue
@@ -442,9 +456,11 @@ class DDPOptimizer:
and param.requires_grad
and not self._ignore_parameter(param)
):
self.add_param(bucket, param, arg.target)
self.add_param(bucket, param, str(arg.target))
def compile_fn(self, gm: fx.GraphModule, example_inputs: list[torch.Tensor]):
def compile_fn(
self, gm: fx.GraphModule, example_inputs: list[torch.Tensor]
) -> CompiledFn:
"""
Implements graph splitting, first determining a set of buckets by counting
parameter sizes in reverse graph order, then invoking the user/backend compiler
@@ -453,7 +469,7 @@ class DDPOptimizer:
"""
# 1: compute the partition map according to DDP bucket logic
buckets = [Bucket()] # (size, param_names)
processed_modules = set()
processed_modules: set[torch.nn.Module] = set()
for node in reversed(gm.graph.nodes):
if node.op in ("output", "placeholder"):
continue
@@ -533,7 +549,9 @@ class DDPOptimizer:
partition_map[node] = idx
split_gm = fx.passes.split_module.split_module(
gm, None, lambda node: partition_map[node]
gm,
None, # type: ignore[arg-type]
lambda node: partition_map[node],
)
# See note [Assumption on Dynamo Metadata]
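
A hypothetical usage sketch built from the signatures annotated above; the eager_backend helper and the 25 MiB bucket cap are illustrative choices, not code from this commit:

# Hypothetical construction example based on the annotated __init__ and
# compile_fn signatures above; not part of the commit itself.
import torch
from torch import fx
from torch._dynamo.backends.distributed import DDPOptimizer


def eager_backend(gm: fx.GraphModule, example_inputs: list[torch.Tensor]):
    # Minimal CompilerFn-shaped callable: "compile" by returning the graph's
    # own forward, i.e. run the captured graph eagerly.
    return gm.forward


opt = DDPOptimizer(
    bucket_bytes_cap=25 * 1024 * 1024,  # mirrors DDP's default 25 MiB bucket cap
    backend_compile_fn=eager_backend,
)

# compile_fn now advertises the (fx.GraphModule, list[torch.Tensor]) -> CompiledFn
# shape typed in this diff; Dynamo's DDP optimizer path invokes it with the
# captured graph and its example inputs.
ddp_aware_backend = opt.compile_fn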