pytorch/torch/distributed/_spmd/api.py
Chien-Chin Huang 250c054bdd [SPMD] Pull the minimal working distribute API and SPMD module to PyTorch (#94802)
Pull the minimal working distribute API and SPMD module into PyTorch. The original code is at https://github.com/pytorch/tau/tree/main/spmd/compiler.

Other main contributors to the original codebase: @anj-s, @lessw2020, @wanchaol, @aazzolini

Differential Revision: [D43197230](https://our.internmc.facebook.com/intern/diff/D43197230/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94802
Approved by: https://github.com/anj-s, https://github.com/wanchaol
2023-02-16 00:36:16 +00:00

from typing import Dict, Optional, Sequence, Tuple

import torch.distributed as dist
import torch.nn as nn

from torch.distributed._spmd.distribute import distribute, Schema
from torch.distributed._spmd.distributed_graph import DistributedGraph
from torch.distributed._tensor import Placement, Replicate


class SPMD(nn.Module):
    def __init__(
        self,
        module: nn.Module,
        schema: Schema,
        input_schemas: Sequence[Placement] = tuple(),
    ) -> None:
        """
        Given a non-distributed nn.Module, distribute the module and apply
        optimizations over the distributed module (fx.GraphModule).

        Args:
            module (nn.Module): The target module.
            schema (Schema): The distributed schema.
            input_schemas (Sequence[Placement]): The schemas of the inputs.
        """
        super().__init__()

        assert schema.placements == [
            Replicate()
        ], "SPMD only supports Replicate() parameters for now"

        # TODO: Fix model initialization with coalescing.
        # This needs to happen post model transformation.
        # Consider an explicit model init API.
        # For now, replicate parameters by broadcasting from rank 0 so that
        # every rank starts from identical weights.
        for p in module.parameters():
            dist.broadcast(p, src=0)

        self._param_schema = schema
        self._input_schemas = input_schemas
        self._compiled_m: Optional[nn.Module] = None
        self._dist_graph = DistributedGraph(orig_module=module)

    def forward(self, *args: Tuple[object], **kwargs: Dict[str, object]) -> object:
        # Compile lazily on the first call, once example inputs are available;
        # subsequent calls reuse the compiled distributed module.
        if self._compiled_m is None:
            self._compiled_m = distribute(
                self._dist_graph,
                self._param_schema,
                self._input_schemas,
                *args,
                **kwargs,
            )

        assert self._compiled_m is not None
        return self._compiled_m(*args, **kwargs)
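
For context, a minimal usage sketch of this wrapper. It assumes a launcher such as torchrun has set up the rendezvous environment, and that Schema is the dataclass from torch.distributed._spmd.distribute holding a DeviceMesh plus a list of placements; the model, backend, and shapes are illustrative, not part of this file.

# usage_sketch.py -- illustrative only; run with e.g.
#   torchrun --nproc_per_node=2 usage_sketch.py
import torch
import torch.distributed as dist
import torch.nn as nn

from torch.distributed._spmd.api import SPMD
from torch.distributed._spmd.distribute import Schema
from torch.distributed._tensor import DeviceMesh, Replicate

dist.init_process_group("gloo")
mesh = DeviceMesh("cpu", list(range(dist.get_world_size())))

# Wrapping broadcasts the parameters from rank 0 at construction time.
model = SPMD(nn.Linear(8, 4), schema=Schema(mesh=mesh, placements=[Replicate()]))

# The first call traces and distributes the module; later calls reuse it.
out = model(torch.randn(2, 8))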