from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, List, Optional, Tuple

import torch
import torch.distributed as dist
import torch.utils._pytree as pytree
from torch._subclasses import FakeTensorMode
from torch.distributed._spmd.data_parallel import (
    DataParallelStyle,
    partition_data_parallel,
)
from torch.distributed._spmd.distribute import _convert_to_distributed, Schema
from torch.distributed._tensor import DeviceMesh, Placement, Replicate, Shard
from torch.fx import GraphModule


class ParallelMode(ABC):
    """
    Basic Parallel Mode interface. Each parallelism pattern should implement
    this interface to describe how to partition and compile the graph in the
    spmd compiler.
    """

    @abstractmethod
    def partition(
        self,
        gm: GraphModule,
        model: torch.nn.Module,
        optimizer: Optional[torch.optim.Optimizer],
        params_and_buffers: Dict[str, Any],
        named_states: Dict[str, Any],
        args: Tuple[Any, ...],
        kwargs: Dict[str, Any],
    ) -> GraphModule:
        """
        Partition a single-device graph into a distributed graph.

        TODO(@wanchaol): some of these arguments are not necessary for
        partitioning, remove the unnecessary ones later.
        """
        raise NotImplementedError()

    @abstractmethod
    def transform_and_compile(self, gm: GraphModule) -> GraphModule:
        """
        Transform and compile a distributed graph with a set of graph
        transformation and optimization passes for each parallel mode.

        The returned result should be a compiled executable graph in the
        distributed environment.
        """
        # TODO: add more necessary arguments to this interface.
        raise NotImplementedError()


class DataParallel(ParallelMode):
    """Data Parallelism mode."""

    def __init__(
        self,
        parallel_style: str = "replicate",
        *,
        input_batch_dim: int = 0,
        custom_passes: Optional[Callable[[GraphModule], GraphModule]] = None,
    ):
        """
        DataParallel Mode that partitions the model and graph into data parallel
        style parallelism (i.e. DDP/FSDP/ZERO-3). It currently supports three
        different parallel styles: "replicate", "fully_shard", and "default".
        See :class:`DataParallelStyle` for more details.

        Args:
            parallel_style (str): parallel style to use. Currently supports
                "replicate", "fully_shard", and "default".

        Keyword args:
            input_batch_dim (int): the batch dimension of the input tensor.
                default: 0
            custom_passes (Callable[[GraphModule], GraphModule], optional):
                A custom callable that overrides the default graph
                transformation and optimization passes.
        """
        if parallel_style == "replicate":
            self.parallel_style = DataParallelStyle.REPLICATE
        elif parallel_style == "fully_shard":
            self.parallel_style = DataParallelStyle.FULLY_SHARD
        elif parallel_style == "default":
            self.parallel_style = DataParallelStyle.DEFAULT
        else:
            raise RuntimeError(f"Unknown parallel style: {parallel_style}")

        # TODO: what if the user passes in an incorrect `input_batch_dim`, how
        # should we detect that and do proper error handling?
        self.input_batch_dim = input_batch_dim

        if custom_passes is not None:
            self._gm_passes: Callable[[GraphModule], GraphModule] = custom_passes
        else:
            # TODO: add a few default passes here.
            self._gm_passes = lambda gm: gm

    def partition(
        self,
        gm: GraphModule,
        model: torch.nn.Module,
        optimizer: Optional[torch.optim.Optimizer],
        params_and_buffers: Dict[str, Any],
        named_states: Dict[str, Any],
        args: Tuple[Any, ...],
        kwargs: Dict[str, Any],
    ) -> GraphModule:
        # TODO: figure out a way to avoid explicit "cuda" mesh.
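        # Build a 1D "cuda" DeviceMesh spanning every rank in the default
        # process group; partition_data_parallel expands the single-device
        # graph over this mesh according to `self.parallel_style`.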
        mesh = DeviceMesh("cuda", torch.arange(dist.get_world_size()))

        gm = partition_data_parallel(
            gm,
            model,
            optimizer,
            params_and_buffers,
            named_states,
            args,
            kwargs,
            mesh,
            self.parallel_style,
            self.input_batch_dim,
        )
        return gm

    def transform_and_compile(self, gm: GraphModule) -> GraphModule:
        """Optimize a distributed graph with a set of optimization passes."""
        # TODO: add more necessary arguments to this interface.
        return self._gm_passes(gm)


class DTensorExpandMode(ParallelMode):
    """
    The DTensor Expand mode. It replicates the parameters and shards the inputs
    to represent DDP-like behavior. It is currently a transient mode before we
    move to the new data parallel expansion.
    """

    def __init__(
        self, custom_passes: Optional[Callable[[GraphModule], GraphModule]] = None
    ):
        self._placements_override: Dict[int, List[Placement]] = {}
        if custom_passes is not None:
            self._gm_passes: Callable[[GraphModule], GraphModule] = custom_passes
        else:
            # TODO: add a few default passes here.
            self._gm_passes = lambda gm: gm

    def partition(
        self,
        gm: GraphModule,
        model: torch.nn.Module,
        optimizer: Optional[torch.optim.Optimizer],
        params_and_buffers: Dict[str, Any],
        named_states: Dict[str, Any],
        args: Tuple[Any, ...],
        kwargs: Dict[str, Any],
    ) -> GraphModule:
        flat_args = pytree.arg_tree_leaves(*args, **kwargs)

        mesh = DeviceMesh("cuda", torch.arange(dist.get_world_size()).cuda())
        shard_schema: Schema = Schema(mesh=mesh, placements=[Shard(0)])
        # FIXME: allow other sharding schemas
        replicate_schema: Schema = Schema(mesh=mesh, placements=[Replicate()])

        inps, schemas = [], []

        # Parameters and buffers are replicated across ranks.
        for p in pytree.tree_leaves(params_and_buffers):
            assert isinstance(p, torch.Tensor), f"expecting Tensor but got {type(p)}"
            inps.append(p)
            schemas.append(replicate_schema)

        # Optimizer states are replicated as well; non-tensor states get a
        # dummy placeholder tensor.
        for o in pytree.tree_leaves(named_states):
            if isinstance(o, torch.Tensor):
                inps.append(o)
                schemas.append(replicate_schema)
            else:
                inps.append(torch.empty(0))
                schemas.append(replicate_schema)

        # Inputs are sharded on the batch dimension unless a placement
        # override was registered for that tensor.
        for a in flat_args:
            if isinstance(a, torch.Tensor):
                inps.append(a)
                if id(a) in self._placements_override:
                    schemas.append(
                        Schema(mesh=mesh, placements=self._placements_override[id(a)])
                    )
                else:
                    schemas.append(shard_schema)
            else:
                # Create dummy tensor and schema for non-tensor inputs for
                # the purpose of dtensor expansion. Non-tensor inputs are
                # guaranteed unused in dispatcher graphs produced by make_fx.
                # However, we still need to respect them so that tensor inputs
                # match with their placeholders.
                inps.append(torch.empty(0))
                schemas.append(shard_schema)

        with FakeTensorMode(allow_non_fake_inputs=True):
            fake_inps = [torch.empty_like(inp) for inp in inps]

        return _convert_to_distributed(
            gm, fake_inps, schemas, default_mesh=mesh, _allow_partial=False
        )[0]

    def transform_and_compile(self, gm: GraphModule) -> GraphModule:
        """
        Transform and compile a distributed graph with a set of graph
        transformation and optimization passes for the dtensor fallback
        parallel mode.
        """
        # TODO: move the transformation passes to this function
        return self._gm_passes(gm)
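
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of this module's API). It assumes
# the spmd compiler frontend has already produced a traced single-device
# GraphModule `gm`, plus `model`, `optimizer`, flattened `params_and_buffers`
# and `named_states`, and the call `args`/`kwargs`; it also assumes an
# initialized process group with CUDA devices.
#
#   mode = DataParallel("fully_shard", input_batch_dim=0)
#   dist_gm = mode.partition(
#       gm, model, optimizer, params_and_buffers, named_states, args, kwargs
#   )
#   compiled_gm = mode.transform_and_compile(dist_gm)
# ---------------------------------------------------------------------------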