typing distributed.py (#160365)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/160365
Approved by: https://github.com/StrongerXi
ghstack dependencies: #160362, #160363, #160364
This commit is contained in:
parent 9faca5f260
commit 453cfa5153
torch/_dynamo/backends/distributed.py
@@ -1,5 +1,3 @@
-# mypy: ignore-errors
-
 """
 This module implements distributed training optimizations for TorchDynamo backends.
 
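With the file-level "# mypy: ignore-errors" removed, mypy now checks this module like any other; the handful of spots it still cannot verify get per-line, error-code-scoped ignores later in this diff. A minimal illustration of that style (not taken from this commit):

# A module-level "# mypy: ignore-errors" comment silences every error in the file.
# An error-code-scoped ignore silences exactly one diagnostic on one line:
x: int = "not an int"  # type: ignore[assignment]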
@@ -21,11 +19,12 @@ of compilation.
 import logging
 import traceback
 from dataclasses import dataclass, field
-from typing import Any, Optional
+from typing import Any, Callable, Optional
 from unittest import mock
 
 import torch
 from torch import fx
+from torch._dynamo.backends.registry import CompiledFn, CompilerFn
 from torch._dynamo.output_graph import GraphCompileReason
 from torch._dynamo.utils import deepcopy_to_fake_tensor, detect_fake_mode
 from torch._logging import trace_structured
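The newly imported CompiledFn and CompilerFn from torch._dynamo.backends.registry are used below to annotate backend callables. Roughly, a CompilerFn maps an FX GraphModule plus example inputs to a compiled callable. A minimal sketch of something with that shape (eager_passthrough is a made-up name for illustration, not part of this commit):

from typing import Any, Callable

import torch
from torch import fx


def eager_passthrough(
    gm: fx.GraphModule, example_inputs: list[torch.Tensor]
) -> Callable[..., Any]:
    # A GraphModule is itself callable, so returning it unchanged is a
    # valid "compiled" result for a do-nothing backend.
    return gm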
@@ -39,7 +38,7 @@ log = logging.getLogger(__name__)
 ddp_graph_log = torch._logging.getArtifactLogger(__name__, "ddp_graphs")
 
 
-def args_str(args):
+def args_str(args: Any) -> str:
     # a debug helper
     if torch.is_tensor(args):
         return f"T[{args.shape}]"
@@ -58,7 +57,7 @@ class Bucket:
     nodes: list[fx.Node] = field(default_factory=list)
 
     # param_ids is just used for unit testing
-    param_ids: list = field(default_factory=list)
+    param_ids: list[int] = field(default_factory=list)
 
     # keep track of any buckets that were extended for logging purposes
     opcount_increased_to_capture_external_output: int = 0
@@ -78,9 +77,9 @@ def bucket_has_external_output(bucket: Bucket) -> bool:
     return False
 
 
-def pretty_print_buckets(buckets: list[Bucket], bucket_bytes_cap: int):
+def pretty_print_buckets(buckets: list[Bucket], bucket_bytes_cap: int) -> None:
     headers = ("Index", "Size (b)", "Param Names")
-    rows = []
+    rows: list[tuple[Optional[int], Optional[int], str]] = []
     extended_buckets = []
     for idx, bucket in enumerate(reversed(buckets)):
         if len(bucket.params) > 0:
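The rows annotation is needed because mypy cannot infer an element type from an empty list literal; spelling it out lets the later append calls type-check. A standalone illustration (the appended values are made up):

from typing import Optional

# Without the annotation, mypy asks for one ("Need type annotation for 'rows'").
rows: list[tuple[Optional[int], Optional[int], str]] = []
rows.append((0, 1048576, "layer1.weight"))
rows.append((None, None, "(extended)"))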
@@ -136,7 +135,7 @@ def pretty_print_buckets(buckets: list[Bucket], bucket_bytes_cap: int):
         log.debug("DDPOptimizer captured no parameters and did not split this graph.")
 
 
-def has_higher_order_op(gm):
+def has_higher_order_op(gm: fx.GraphModule) -> bool:
     # Check if there is a higher order op in the graph
     for node in gm.graph.nodes:
         if node.op == "get_attr":
@@ -146,7 +145,7 @@ def has_higher_order_op(gm):
     return False
 
 
-def propagate_metadata(orig_gm, split_gm) -> None:
+def propagate_metadata(orig_gm: fx.GraphModule, split_gm: fx.GraphModule) -> None:
     for name, module in split_gm.named_modules():
         if "." not in name and len(name):
             # TODO: add split id to CompileId: https://github.com/pytorch/tlparse/pull/83/files#r1880649384
@@ -154,7 +153,7 @@ def propagate_metadata(orig_gm, split_gm) -> None:
             module._param_name_to_source = orig_gm._param_name_to_source
 
 
-def propagate_dynamo_source(orig_gm, split_gm) -> None:
+def propagate_dynamo_source(orig_gm: fx.GraphModule, split_gm: fx.GraphModule) -> None:
     name_to_dynamo_source = {}
     for node in orig_gm.graph.find_nodes(op="placeholder"):
         name_to_dynamo_source[node.name] = node._dynamo_source
@@ -168,12 +167,19 @@ def propagate_dynamo_source(orig_gm, split_gm) -> None:
 
 # compile each of the partitioned submodules using the user-provided compiler
 class SubmodCompiler(torch.fx.interpreter.Interpreter):
-    def __init__(self, module, compiler, fake_mode) -> None:
+    def __init__(
+        self,
+        module: fx.GraphModule,
+        compiler: CompilerFn,
+        fake_mode: torch._subclasses.fake_tensor.FakeTensorMode,
+    ) -> None:
         super().__init__(module)
         self.compiler = compiler
         self.fake_mode = fake_mode
 
-    def compile_submod(self, input_mod, args, kwargs):
+    def compile_submod(
+        self, input_mod: fx.GraphModule, args: list[torch.Tensor], kwargs: Any
+    ) -> Any:
         """
         Compile the submodule,
         using a wrapper to make sure its output is always a tuple,
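The fake_mode parameter is now annotated as torch._subclasses.fake_tensor.FakeTensorMode. As a hedged aside on what that type is: tensors created while such a mode is active are FakeTensors, which carry shape and dtype metadata but no real storage, which is what lets submodules be compiled without touching real data. A small sketch:

import torch
from torch._subclasses.fake_tensor import FakeTensorMode

fake_mode = FakeTensorMode()
with fake_mode:
    t = torch.empty(8, 8)

# t tracks shape/dtype but allocates no real storage.
print(type(t).__name__, t.shape, t.dtype)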
@@ -182,12 +188,14 @@ class SubmodCompiler(torch.fx.interpreter.Interpreter):
         assert len(kwargs) == 0, "We assume only args for these modules"
 
         class WrapperModule(torch.nn.Module):
-            def __init__(self, submod, unwrap_singleton_tuple) -> None:
+            def __init__(
+                self, submod: Callable[..., Any], unwrap_singleton_tuple: bool
+            ) -> None:
                 super().__init__()
                 self.submod = submod
                 self.unwrap_singleton_tuple = unwrap_singleton_tuple
 
-            def forward(self, *args):
+            def forward(self, *args: Any) -> Any:
                 x = self.submod(*args)
                 # TODO(whc)
                 # for some reason the isinstance check is necessary if I split one node per submod
@@ -205,12 +213,12 @@ class SubmodCompiler(torch.fx.interpreter.Interpreter):
                     sn.args = (sn.args,)
 
         input_mod.recompile()
-        input_mod.compile_subgraph_reason = GraphCompileReason(
+        input_mod.compile_subgraph_reason = GraphCompileReason(  # type: ignore[assignment]
             "DDPOptimizer intentional graph-break (See Note [DDPOptimizer])."
             " Set `torch._dynamo.config.optimize_ddp = False` to disable.",
             [
                 # it's close to useless to get a real stacktrace here, and quite verbose.
-                traceback.FrameSummary(__file__, 0, DDPOptimizer),
+                traceback.FrameSummary(__file__, 0, "DDPOptimizer"),
             ],
         )
 
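The traceback.FrameSummary change is a plain typing fix: the third positional argument is the frame's name and is typed as str in the stdlib stubs, so the DDPOptimizer class object becomes the string "DDPOptimizer". For reference, standard-library usage looks like this (not part of this commit):

import traceback

# FrameSummary(filename, lineno, name): name is expected to be a string.
fs = traceback.FrameSummary(__file__, 0, "DDPOptimizer")
print(fs.filename, fs.lineno, fs.name)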
@@ -257,7 +265,7 @@ class SubmodCompiler(torch.fx.interpreter.Interpreter):
         assert isinstance(kwargs, dict)
 
         if n.op == "call_module":
-            real_mod = self.fetch_attr(n.target)
+            real_mod = self.fetch_attr(str(n.target))
             if self.fake_mode:
                 curr_submod = deepcopy_to_fake_tensor(real_mod, self.fake_mode)
             else:
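Interpreter.fetch_attr expects a string path, while fx.Node.target is typed as a union (it is only a plain string for call_module and get_attr nodes), so wrapping it in str(...) satisfies the checker without changing runtime behavior. A short sketch showing that call_module targets are already strings at runtime (the toy module M is made up for illustration):

import torch
from torch import fx


class M(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.lin = torch.nn.Linear(4, 4)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.lin(x)


gm = fx.symbolic_trace(M())
for n in gm.graph.nodes:
    if n.op == "call_module":
        # For call_module nodes the target is the submodule's qualified name.
        print(repr(n.target), type(n.target).__name__)  # 'lin' str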
@@ -287,10 +295,10 @@ class SubmodCompiler(torch.fx.interpreter.Interpreter):
                 def __init__(self) -> None:
                     self.tc = torch._guards.TracingContext.try_get()
                     assert self.tc
-                    torch._guards.TracingContext.try_get().fakify_first_call = True
+                    self.tc.fakify_first_call = True
 
                 def __del__(self) -> None:
-                    self.tc.fakify_first_call = False
+                    self.tc.fakify_first_call = False  # type: ignore[union-attr]
 
             # For aot_eager and other backends, tracing context is not set
             has_tracing_context = torch._guards.TracingContext.try_get() is not None
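This is standard Optional narrowing: TracingContext.try_get() may return None, the assert narrows self.tc inside __init__, and reusing self.tc avoids a second lookup. In __del__ the narrowing from __init__ is not visible to mypy, hence the union-attr ignore. A generic sketch with stand-in names (Ctx and try_get are illustrative, not the real TracingContext API):

from typing import Optional


class Ctx:
    fakify_first_call: bool = False


def try_get() -> Optional[Ctx]:
    return Ctx()


tc = try_get()
assert tc  # narrows Optional[Ctx] to Ctx for the type checker
tc.fakify_first_call = True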
@@ -308,9 +316,9 @@ class SubmodCompiler(torch.fx.interpreter.Interpreter):
 
             # We update the original (outer) graph with a call into the compiled module
             # instead of the uncompiled one.
-            self.module.delete_submodule(n.target)
-            n.target = "compiled_" + n.target
-            self.module.add_submodule(n.target, compiled_submod_real)
+            self.module.delete_submodule(n.target)  # type: ignore[operator]
+            n.target = "compiled_" + n.target  # type: ignore[operator]
+            self.module.add_submodule(n.target, compiled_submod_real)  # type: ignore[operator]
 
             # Finally, we have to produce inputs for use compiling the next submodule,
             # and these need to be FakeTensors, so we execute the module under fake_mode
@@ -398,7 +406,7 @@ class DDPOptimizer:
     def __init__(
         self,
         bucket_bytes_cap: int,
-        backend_compile_fn,
+        backend_compile_fn: CompilerFn,
         first_bucket_cap: Optional[int] = None,
     ) -> None:
         if first_bucket_cap is not None:
@@ -416,21 +424,27 @@ class DDPOptimizer:
 
         self.backend_compile_fn = backend_compile_fn
 
-    def _ignore_parameter(self, parameter):
+    def _ignore_parameter(self, parameter: torch.nn.Parameter) -> bool:
         return hasattr(parameter, "_ddp_ignored") and parameter._ddp_ignored
 
-    def add_param(self, bucket, param, name):
+    def add_param(self, bucket: Bucket, param: torch.nn.Parameter, name: str) -> None:
         bucket.size += param.untyped_storage().nbytes()
         bucket.params.append(name)
         bucket.param_ids.append(id(param))
 
-    def add_module_params_to_bucket(self, mod, bucket, processed_modules, prefix):
+    def add_module_params_to_bucket(
+        self,
+        mod: torch.nn.Module,
+        bucket: Bucket,
+        processed_modules: set[torch.nn.Module],
+        prefix: str,
+    ) -> None:
         processed_modules.add(mod)
         for name, param in mod.named_parameters():
             if param.requires_grad and not self._ignore_parameter(param):
                 self.add_param(bucket, param, f"{prefix}_{name}")
 
-    def add_param_args(self, bucket, node):
+    def add_param_args(self, bucket: Bucket, node: fx.Node) -> None:
         for arg in node.args:
             if not isinstance(arg, torch.fx.node.Node):
                 continue
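add_param accumulates bucket sizes from each parameter's untyped storage, which is why the natural annotations here are Bucket, torch.nn.Parameter, and str. A quick illustration of the size computation (the shape is made up):

import torch

p = torch.nn.Parameter(torch.zeros(128, 128))
# 128 * 128 float32 elements -> 65536 bytes of underlying storage.
print(p.untyped_storage().nbytes())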
@@ -442,9 +456,11 @@ class DDPOptimizer:
                 and param.requires_grad
                 and not self._ignore_parameter(param)
             ):
-                self.add_param(bucket, param, arg.target)
+                self.add_param(bucket, param, str(arg.target))
 
-    def compile_fn(self, gm: fx.GraphModule, example_inputs: list[torch.Tensor]):
+    def compile_fn(
+        self, gm: fx.GraphModule, example_inputs: list[torch.Tensor]
+    ) -> CompiledFn:
         """
         Implements graph splitting, first determining a set of of buckets by counting
         parameter sizes in reverse graph order, then invoking the user/backend compiler
@@ -453,7 +469,7 @@ class DDPOptimizer:
         """
         # 1: compute the partition map according to DDP bucket logic
         buckets = [Bucket()]  # (size, param_names)
-        processed_modules = set()
+        processed_modules: set[torch.nn.Module] = set()
         for node in reversed(gm.graph.nodes):
             if node.op in ("output", "placeholder"):
                 continue
@@ -533,7 +549,9 @@ class DDPOptimizer:
             partition_map[node] = idx
 
         split_gm = fx.passes.split_module.split_module(
-            gm, None, lambda node: partition_map[node]
+            gm,
+            None,  # type: ignore[arg-type]
+            lambda node: partition_map[node],
         )
 
         # See note [Assumption on Dynamo Metadata]
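fx.passes.split_module.split_module takes the graph, a root module (unused by this call, hence the None and the arg-type ignore), and a callback mapping each node to a partition index. A minimal, hedged sketch of the same call pattern on a toy module (Toy and the two-way partition are made up for illustration):

import torch
from torch import fx
from torch.fx.passes.split_module import split_module


class Toy(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(x + 1).sum()


gm = fx.symbolic_trace(Toy())
# Send earlier nodes to partition 0 and later ones to partition 1.
partition_map = {n: (0 if i < 2 else 1) for i, n in enumerate(gm.graph.nodes)}
split_gm = split_module(gm, None, lambda node: partition_map[node])
print([name for name, _ in split_gm.named_children()])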