from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, List, Optional, Tuple

import torch
import torch.distributed as dist
import torch.utils._pytree as pytree
from torch._subclasses import FakeTensorMode
from torch.distributed._spmd.data_parallel import (
    DataParallelStyle,
    partition_data_parallel,
)
from torch.distributed._spmd.distribute import _convert_to_distributed, Schema
from torch.distributed._tensor import DeviceMesh, Placement, Replicate, Shard

from torch.fx import GraphModule


class ParallelMode(ABC):
    """
    Basic Parallel Mode interface. Each parallelism pattern should implement
    this interface to describe how to partition and compile the graph in the
    SPMD compiler.
    """

    @abstractmethod
    def partition(
        self,
        gm: GraphModule,
        model: torch.nn.Module,
        optimizer: Optional[torch.optim.Optimizer],
        params_and_buffers: Dict[str, Any],
        named_states: Dict[str, Any],
        args: Tuple[Any, ...],
        kwargs: Dict[str, Any],
    ) -> GraphModule:
        """
        Partition a single-device graph into a distributed graph.

        TODO(@wanchaol): some of these arguments are not necessary for
        partitioning, remove the unnecessary ones later.
        """
        raise NotImplementedError()

    @abstractmethod
    def transform_and_compile(self, gm: GraphModule) -> GraphModule:
        """
        Transform and compile a distributed graph with a set of graph
        transformation and optimization passes for each parallel mode.

        The returned result should be a compiled executable graph in
        the distributed environment.
        """
        # TODO: add more necessary arguments to this interface.
        raise NotImplementedError()

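# Illustrative sketch (not part of the original module): a minimal
# ParallelMode implementation, showing what a subclass of the interface
# above is expected to provide. The class name and its no-op behavior are
# hypothetical; a real mode would rewrite `gm` into a distributed graph the
# way DataParallel below does.
class _NoOpParallelMode(ParallelMode):
    """Example mode that leaves the single-device graph untouched."""

    def partition(
        self,
        gm: GraphModule,
        model: torch.nn.Module,
        optimizer: Optional[torch.optim.Optimizer],
        params_and_buffers: Dict[str, Any],
        named_states: Dict[str, Any],
        args: Tuple[Any, ...],
        kwargs: Dict[str, Any],
    ) -> GraphModule:
        # No partitioning: hand back the single-device graph unchanged.
        return gm

    def transform_and_compile(self, gm: GraphModule) -> GraphModule:
        # No optimization passes; recompile so the generated code matches
        # the (possibly edited) graph before returning it.
        gm.recompile()
        return gm
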
class DataParallel(ParallelMode):
    """Data Parallelism mode."""

    def __init__(
        self,
        parallel_style: str = "replicate",
        *,
        input_batch_dim: int = 0,
        custom_passes: Optional[Callable[[GraphModule], GraphModule]] = None,
    ):
        """
        DataParallel mode that partitions the model and graph into data
        parallel style parallelism (i.e. DDP/FSDP/ZeRO-3). It currently
        supports three different parallel styles: "replicate", "fully_shard",
        and "default". See :class:`DataParallelStyle` for more details.

        Args:
            parallel_style (str): parallel style to use. Currently supports
                "replicate", "fully_shard", and "default".

        Keyword args:
            input_batch_dim (int): the batch dimension of the input tensor.
                default: 0
            custom_passes (Callable[[GraphModule], GraphModule], optional):
                A custom callable that overrides the default graph
                transformation and optimization passes.
        """
        if parallel_style == "replicate":
            self.parallel_style = DataParallelStyle.REPLICATE
        elif parallel_style == "fully_shard":
            self.parallel_style = DataParallelStyle.FULLY_SHARD
        elif parallel_style == "default":
            self.parallel_style = DataParallelStyle.DEFAULT
        else:
            raise RuntimeError(f"Unknown parallel style: {parallel_style}")

        # TODO: what if the user passes in an incorrect `input_batch_dim`, how
        # should we detect that and do proper error handling?
        self.input_batch_dim = input_batch_dim

        if custom_passes is not None:
            self._gm_passes: Callable[[GraphModule], GraphModule] = custom_passes
        else:
            # TODO: add a few default passes here.
            self._gm_passes = lambda gm: gm

    def partition(
        self,
        gm: GraphModule,
        model: torch.nn.Module,
        optimizer: Optional[torch.optim.Optimizer],
        params_and_buffers: Dict[str, Any],
        named_states: Dict[str, Any],
        args: Tuple[Any, ...],
        kwargs: Dict[str, Any],
    ) -> GraphModule:
        # TODO: figure out a way to avoid explicit "cuda" mesh.
        mesh = DeviceMesh("cuda", torch.arange(dist.get_world_size()))

        gm = partition_data_parallel(
            gm,
            model,
            optimizer,
            params_and_buffers,
            named_states,
            args,
            kwargs,
            mesh,
            self.parallel_style,
            self.input_batch_dim,
        )
        return gm

    def transform_and_compile(self, gm: GraphModule) -> GraphModule:
        """Optimize a distributed graph with a set of optimization passes."""
        # TODO: add more necessary arguments to this interface.
        return self._gm_passes(gm)

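# Illustrative sketch (not part of the original module): how a caller might
# construct DataParallel with a custom graph pass. The helper names are
# hypothetical; the pass below only prints the partitioned graph before
# handing it back, while a real pass would return a transformed GraphModule.
def _example_data_parallel_mode() -> DataParallel:
    def _debug_print_pass(gm: GraphModule) -> GraphModule:
        # Inspect the distributed graph produced by `partition`.
        gm.graph.print_tabular()
        return gm

    # Fully-sharded (ZeRO-3 style) data parallelism over batch dim 0, with
    # the debug pass run by `transform_and_compile`.
    return DataParallel(
        "fully_shard",
        input_batch_dim=0,
        custom_passes=_debug_print_pass,
    )
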
class DTensorExpandMode(ParallelMode):
    """
    The DTensor Expand mode. It replicates the parameters and shards the
    inputs to represent DDP-like behavior. It is currently a transient mode
    before we move to the new data parallel expansion.
    """

    def __init__(
        self, custom_passes: Optional[Callable[[GraphModule], GraphModule]] = None
    ):
        self._placements_override: Dict[int, List[Placement]] = {}
        if custom_passes is not None:
            self._gm_passes: Callable[[GraphModule], GraphModule] = custom_passes
        else:
            # TODO: add a few default passes here.
            self._gm_passes = lambda gm: gm

    def partition(
        self,
        gm: GraphModule,
        model: torch.nn.Module,
        optimizer: Optional[torch.optim.Optimizer],
        params_and_buffers: Dict[str, Any],
        named_states: Dict[str, Any],
        args: Tuple[Any, ...],
        kwargs: Dict[str, Any],
    ) -> GraphModule:
        flat_args, _ = pytree.tree_flatten(list(args) + list(kwargs.values()))

        mesh = DeviceMesh("cuda", torch.arange(dist.get_world_size()).cuda())
        shard_schema: Schema = Schema(mesh=mesh, placements=[Shard(0)])
        # FIXME: allow other sharding schemas
        replicate_schema: Schema = Schema(mesh=mesh, placements=[Replicate()])

        inps, schemas = [], []

        for p in pytree.tree_flatten(params_and_buffers)[0]:
            assert isinstance(p, torch.Tensor), f"expecting Tensor but got {type(p)}"
            inps.append(p)
            schemas.append(replicate_schema)

        for o in pytree.tree_flatten(named_states)[0]:
            if isinstance(o, torch.Tensor):
                inps.append(o)
                schemas.append(replicate_schema)
            else:
                inps.append(torch.empty(0))
                schemas.append(replicate_schema)

        for a in flat_args:
            if isinstance(a, torch.Tensor):
                inps.append(a)
                if id(a) in self._placements_override:
                    schemas.append(
                        Schema(mesh=mesh, placements=self._placements_override[id(a)])
                    )
                else:
                    schemas.append(shard_schema)
            else:
                # Create a dummy tensor and schema for non-tensor inputs for
                # the purpose of dtensor expansion. Non-tensor inputs are
                # guaranteed unused in dispatcher graphs produced by make_fx.
                # However, we still need to respect them so that tensor inputs
                # match with their placeholders.
                inps.append(torch.empty(0))
                schemas.append(shard_schema)

        with FakeTensorMode(allow_non_fake_inputs=True):
            fake_inps = [torch.empty_like(inp) for inp in inps]

        return _convert_to_distributed(
            gm, fake_inps, schemas, default_mesh=mesh, _allow_partial=False
        )[0]

    def transform_and_compile(self, gm: GraphModule) -> GraphModule:
        """
        Transform and compile a distributed graph with a set of graph
        transformation and optimization passes for the dtensor fallback
        parallel mode.
        """
        # TODO: move the transformation passes into this function.
        return self._gm_passes(gm)
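
# Illustrative sketch (not part of the original module): DTensorExpandMode
# replicates parameters and shards positional tensor inputs on dim 0 by
# default; the private `_placements_override` dict lets a caller pin a
# specific input tensor to a different placement, keyed by the tensor's id()
# as `partition` expects. The function and tensor argument are hypothetical
# and shown only to demonstrate the override mechanism.
def _example_dtensor_expand_mode(
    non_batched_input: torch.Tensor,
) -> DTensorExpandMode:
    mode = DTensorExpandMode()
    # Keep this particular input replicated instead of sharded along dim 0.
    mode._placements_override[id(non_batched_input)] = [Replicate()]
    return mode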