My first attempt was to apply the same solution that proxy_tensor.py uses for other in-place ops. However, foreach is different in that its schema in `native_functions.yaml` does not return anything, whereas ops like `addcmul_` and `addcdiv_` do return Tensors (thanks bdhirsh for teaching me this!). As a result, the proxy output during tracing does not wrap anything, and hence we cannot correctly connect it with subsequent operators. Modifying `native_functions.yaml` is not a preferred solution. After discussing with bdhirsh, the temporary solution is to do foreach functionalization as a graph pass for now. Later, when https://github.com/pytorch/pytorch/issues/97852 is addressed, we will switch to default functionalization.

Edit: the latest version follows @bdhirsh's suggestion of using the `make_fx` `decomposition_table` instead of implementing manual fx.Graph transforms to functionalize `_foreach_add_`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/97853
Approved by: https://github.com/fegin, https://github.com/wanchaol
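For reference, here is a minimal sketch (not from the PR itself) of the decomposition-table approach described above; `fn`, `xs`, and `ys` are made-up names, and `_foreach_add_decomp` mirrors the helper defined in the module below:

import torch
from functorch import make_fx

aten = torch.ops.aten

def _foreach_add_decomp(self, other, alpha=1):
    # Re-express the in-place foreach op as its functional variant plus copies.
    updated = aten._foreach_add.List(self, other, alpha=alpha)
    for s, u in zip(self, updated):
        s.copy_(u)

def fn(xs, ys):
    torch._foreach_add_(xs, ys)
    return xs

xs = [torch.ones(2), torch.ones(3)]
ys = [torch.ones(2), torch.ones(3)]
gm = make_fx(
    fn, decomposition_table={aten._foreach_add_.List: _foreach_add_decomp}
)(xs, ys)
print(gm.graph)  # _foreach_add_ shows up as _foreach_add + copy_ in the trace

With the table in place, each traced `_foreach_add_` call is recorded as its functional `_foreach_add` counterpart followed by `copy_` ops, so the traced outputs connect correctly to subsequent operators.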
395 lines
14 KiB
Python
from abc import ABC, abstractmethod
from contextlib import contextmanager, nullcontext
from copy import copy
from dataclasses import dataclass
from functools import wraps, partial
from typing import (
    Any,
    Callable,
    Dict,
    Optional,
    Sequence,
    Tuple,
    Type,
    Union,
    cast,
)

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.utils._pytree as pytree
from torch import fx
from torch.distributed._spmd.distribute import (
    _convert_to_distributed,
    distribute,
    Schema,
)
from torch.distributed._spmd.distributed_graph import DistributedGraph
from torch.distributed._tensor import (
    DeviceMesh,
    Placement,
    Replicate,
    Shard,
)
from torch.nn.utils import stateless
from functorch import make_fx
from torch.nn.utils._named_member_accessor import NamedMemberAccessor


class SPMD(nn.Module):
    def __init__(
        self,
        module: nn.Module,
        schema: Schema,
        input_schemas: Sequence[Placement] = tuple(),
    ) -> None:
        """
        Given a non-distributed nn.Module, distribute the module and apply
        optimizations over the distributed module (fx.GraphModule).

        Args:
            module (nn.Module): The target module.
            schema (Schema): The distributed schema.
            input_schemas (Sequence[Placement]): The schemas of the inputs.
        """
        super().__init__()
        assert schema.placements == [
            Replicate()
        ], "SPMD only supports Replicate() parameters for now"

        # TODO: Fix model initialization with coalescing.
        # This needs to happen post model transformation.
        # Consider an explicit model init API.
        for p in module.parameters():
            dist.broadcast(p, src=0)

        self._param_schema = schema
        self._input_schemas = input_schemas
        self._compiled_m: Optional[nn.Module] = None
        self._dist_graph = DistributedGraph(orig_module=module)

    def forward(
        self, *args: Tuple[object], **kwargs: Dict[str, object]
    ) -> object:
        if self._compiled_m is None:
            self._compiled_m = distribute(
                self._dist_graph,
                self._param_schema,
                self._input_schemas,
                *args,
                **kwargs,
            )

        assert self._compiled_m is not None
        return self._compiled_m(*args, **kwargs)


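# Example (minimal sketch, not part of the original file, assuming the default
# process group is already initialized and ``MyModel`` is a user-defined
# nn.Module):
#
#   model = SPMD(
#       MyModel().cuda(),
#       schema=Schema(
#           mesh=DeviceMesh("cuda", torch.arange(dist.get_world_size())),
#           placements=[Replicate()],
#       ),
#   )
#   model(inp).sum().backward()

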
class Override(ABC):
    r"""
    Override the tracing and transformation behavior of :meth:`~torch.distributed._spmd.compile`.
    This is useful when any part of the model is not traceable or if you prefer
    to not trace it for any reason. More specifically, users can implement
    :meth:`torch.distributed._spmd.Override.replacement` to replace the original
    submodule with the returned new submodule. The new submodule contains
    operations that users prefer to be traced, which can simply be a dummy
    placeholder operator. After tracing, users can implement
    :meth:`torch.distributed._spmd.Override.transform` to transform the traced
    graph, where the dummy placeholder operator serves as an anchor to insert
    new sub-graphs.
    """

    @abstractmethod
    def replacement(self, orig_submodule: torch.nn.Module) -> torch.nn.Module:
        r"""
        Implement this method to return a new :class:`nn.Module` instance to
        replace the ``orig_submodule`` argument in the model. This helps if
        ``orig_submodule`` is not traceable or should not be traced.

        Args:
            orig_submodule (:class:`nn.Module`): the original submodule instance to replace.

        Returns:
            A new :class:`nn.Module` instance to replace the original one.
        """
        pass

    @abstractmethod
    def transform(
        self, gm: fx.GraphModule, schema_map: Dict[str, Schema]
    ) -> fx.Graph:
        r"""
        Given a DTensor-expanded graph and the sharding schema for every node,
        conduct additional transformation for the sub-graph from the :class:`nn.Module`
        returned by :meth:`torch.distributed._spmd.Override.replacement` if
        necessary.

        Args:
            gm (:class:`fx.GraphModule`): a DTensor-expanded graph.
            schema_map (Dict[str, :class:`Schema`]): a dictionary that maps from
                node name to DTensor schema.

        Returns:
            The :class:`fx.Graph` after transformation.
        """
        pass


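# Example (minimal sketch, not part of the original file, assuming a
# hypothetical non-traceable ``MyMoE`` submodule and a traceable ``DummyMoE``
# stand-in that only runs a placeholder op):
#
#   class MoEOverride(Override):
#       def replacement(self, orig_submodule: torch.nn.Module) -> torch.nn.Module:
#           return DummyMoE(orig_submodule)  # traceable placeholder
#
#       def transform(
#           self, gm: fx.GraphModule, schema_map: Dict[str, Schema]
#       ) -> fx.Graph:
#           # Locate the placeholder op in gm.graph, splice in the real MoE
#           # sub-graph here, then return the edited graph.
#           return gm.graph
#
#   compile(module_override={MyMoE: MoEOverride()})(train_step)

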
def _dtensor_expand(
    gm: fx.GraphModule,
    args: Tuple[Any, ...],
    kwargs: Dict[str, Any],
    named_states: Dict[str, Any],
    params_and_buffers: Dict[str, Any],
) -> Tuple[fx.GraphModule, Dict[str, Schema]]:
    flat_args, _ = pytree.tree_flatten(list(args) + list(kwargs.values()))

    mesh = DeviceMesh("cuda", torch.arange(dist.get_world_size()).cuda())
    shard_schema: Schema = Schema(mesh=mesh, placements=[Shard(0)])
    # FIXME: allow other sharding schemas
    replicate_schema: Schema = Schema(mesh=mesh, placements=[Replicate()])

    inps, schemas = [], []
    for a in flat_args:
        if isinstance(a, torch.Tensor):
            inps.append(a)
            schemas.append(shard_schema)
        elif isinstance(a, (nn.Module, torch.optim.Optimizer)):
            # nn.Module or optimizer placeholder is captured by make_fx but
            # never used in the graph
            inps.append(torch.empty(0))
            schemas.append(shard_schema)

    for o in pytree.tree_flatten(named_states)[0]:
        if isinstance(o, torch.Tensor):
            inps.append(o)
            schemas.append(replicate_schema)
        else:
            inps.append(torch.empty(0))
            schemas.append(replicate_schema)

    for p in pytree.tree_flatten(params_and_buffers)[0]:
        assert isinstance(
            p, torch.Tensor
        ), f"expecting Tensor but got {type(p)}"
        inps.append(p)
        schemas.append(replicate_schema)

    return _convert_to_distributed(gm, inps, schemas, _allow_partial=False)


@contextmanager
def _rematerialize_optimizer(
    opt: torch.optim.Optimizer,
    named_states: Dict[str, Any],
    params: Dict[str, nn.Parameter],
):
    assert opt is not None

    # update opt.state with proxy tensors
    orig_states: Dict[str, Any] = copy(opt.state)
    for n in named_states:
        # named_states is keyed by parameter FQN strings, but the optimizer
        # keys its state by Parameter, so translate through ``params``.
        opt.state[params[n]] = named_states[n]  # type: ignore[index]

    # FIXME: support multiple parameter groups
    param_group = opt.param_groups[0]
    orig_params = param_group["params"]
    # FIXME(@mrshenli): exclude buffers
    param_group["params"] = params.values()

    try:
        yield
    finally:
        param_group["params"] = orig_params
        opt.state.update(orig_states)


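# Example (minimal sketch, not part of the original file): given an optimizer
# built over ``mod.parameters()``, the context manager above lets traced code
# look up optimizer state through FQN-keyed proxies instead of the
# Parameter-keyed ``opt.state``:
#
#   params = dict(mod.named_parameters())
#   named_states = {n: opt.state[p] for n, p in params.items() if p in opt.state}
#   with _rematerialize_optimizer(opt, named_states, params):
#       opt.step()

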
aten = torch.ops.aten # pyre-ignore


@contextmanager
def _enable_compile():
    # The return value of torch._utils.is_compiling changes optimizer behavior.
    # We need that function to return True to include the optimizer in the graph.
    # See: https://github.com/pytorch/pytorch/blob/a524123c91ab399c9dd6882c1189596dd77e7734/torch/optim/optimizer.py#L41
    def f_true():
        return True

    orig_is_compiling_code = torch._utils.is_compiling.__code__
    torch._utils.is_compiling.__code__ = f_true.__code__
    try:
        yield
    finally:
        torch._utils.is_compiling.__code__ = orig_is_compiling_code


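# Example (minimal sketch, not part of the original file): within the context,
# ``torch._utils.is_compiling()`` reports True, so optimizers take their
# "compiling" code path and their ops can be captured by make_fx:
#
#   with _enable_compile():
#       assert torch._utils.is_compiling()

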
@dataclass
class _CompiledResult:
    gm: fx.GraphModule
    mod: nn.Module
    opt: Optional[torch.optim.Optimizer]
    named_states: Dict[str, torch.Tensor]
    params_and_buffers: Dict[str, torch.Tensor]


def _compile(
    func: Callable,
    module_override: Optional[Dict[Type[Any], Override]],
    *args: Any,
    **kwargs: Any,
) -> _CompiledResult:
    # 1. Extract nn.Module and Optimizer from args and kwargs
    # FIXME(@mrshenli): support multiple nn.Module instances
    # FIXME(@mrshenli): support multiple Optimizer instances
    # FIXME(@mrshenli): need to broadcast model to sync parameters
    mod, opt = None, None
    for arg in pytree.tree_flatten(list(args) + list(kwargs.values()))[0]:
        if isinstance(arg, nn.Module):
            assert mod is None, "Only support single nn.Module for now"
            mod = arg
        if isinstance(arg, torch.optim.Optimizer):
            assert opt is None, "Only support single Optimizer for now"
            opt = arg

    assert (
        mod is not None
    ), "Couldn't find an nn.Module instance in the arguments."

    # 2. Override target submodules (e.g., MoE) with dummy replacements
    if module_override:
        accessor = NamedMemberAccessor(mod)

        for typ, override in module_override.items():
            for name, submodule in mod.named_modules():
                if isinstance(submodule, typ):
                    accessor.swap_submodule(
                        name, override.replacement(submodule)
                    )

    # 3. Trace a stateless version of the train_step
    params_and_buffers: Dict[str, Union[torch.Tensor, nn.Parameter]] = {
        **dict(mod.named_parameters(remove_duplicate=False)),
        **dict(mod.named_buffers(remove_duplicate=False)),
    }

    named_states = {}
    if opt is not None:
        opt_states, spec = pytree.tree_flatten(dict(opt.state))

        # Pass named_states instead of opt.state to stateless_func, because
        # the latter uses nn.Parameter as keys. During tracing, we need to
        # make sure optimizers can find the states using proxy tensors.
        for n, p in params_and_buffers.items():
            if p in opt.state:
                # named_states is keyed by FQN strings, while opt.state is
                # keyed by Parameter objects.
                named_states[n] = opt.state[p]  # type: ignore[index]

    # Lift states and parameters as function arguments so that make_fx
    # can trace operations applied to them.
    def stateless_func(func, args, kwargs, named_states, params_and_buffers):
        with stateless._reparametrize_module(
            cast(nn.Module, mod), params_and_buffers
        ), _rematerialize_optimizer(
            opt, named_states, params_and_buffers
        ) if opt else nullcontext():
            ret = func(*args, **kwargs)
            # make sure updated parameters are returned
            return ret, list(mod.parameters())  # type: ignore[union-attr]

    # FIXME: Using symbolic tracing to work around this. Otherwise it hits a
    # shape mismatch error, as we use local inputs to trace the local graph
    # and use DTensor to expand operators, where DTensor's shape is the
    # global shape.
    with _enable_compile():
        # FIXME(@mrshenli): functionalization does not work for our use
        # case yet. Use explicit decompositions for foreach ops.
        # Remove this when the following issue is addressed.
        # Issue: https://github.com/pytorch/pytorch/issues/97852
        gm = make_fx(
            partial(stateless_func, func),
            tracing_mode="symbolic",
            decomposition_table={aten._foreach_add_.List: _foreach_add_decomp},
            _allow_non_fake_inputs=False,
        )(args, kwargs, named_states, params_and_buffers)

    # 4. Use DTensor to insert collectives
    gm, name_to_spec = _dtensor_expand(
        gm, args, kwargs, named_states, params_and_buffers
    )

    # 5. Replace previously inserted dummy placeholders with real sub-graphs.
    if module_override:
        for _, override in module_override.items():
            gm = override.transform(gm, name_to_spec)

    return _CompiledResult(gm, mod, opt, named_states, params_and_buffers)


def _foreach_add_decomp(self, other, alpha=1):
    # Decompose the in-place ``_foreach_add_`` into its functional variant
    # plus per-tensor copies, so make_fx traces a functionalized graph.
    self_updated = aten._foreach_add.List(self, other, alpha=alpha)
    for s, s_u in zip(self, self_updated):
        s.copy_(s_u)


# Note that the Python convention of __dict__ requires the key to be str.
# TODO: ensure the key is unique.
COMPILED_OBJECT_KEY = "_compiled_obj"


def compile(
    module_override: Optional[Dict[Type[Any], Override]] = None,
    gm_transformation: Optional[
        Callable[[fx.GraphModule], fx.GraphModule]
    ] = None,
):
    r"""
    Compile and optimize a callable, which can be a train step within a training
    loop. This method will extract :class:`nn.Module` and :class:`torch.optim.Optimizer`
    instances from the input arguments and trace operations applied to their
    parameters and states.

    Args:
        module_override (Optional[Dict[Type[Any], Override]]): a dictionary that
            maps from target :class:`nn.Module` types to :class:`Override` objects.
            The :class:`Override` objects provide :class:`nn.Module` replacements
            during tracing and a graph transformation function after tracing.
            (Default: ``None``)
        gm_transformation (Optional[Callable[[fx.GraphModule], fx.GraphModule]]):
            a callback that will be called after the original callable is
            compiled and distributed (usually after the first iteration) to
            transform the compiled GraphModule into a new optimized one.
    """

    def inner(func: Callable):
        @wraps(func)
        def wrapper(*args, **kwargs):
            first_iter = False
            # Put the COMPILED_OBJECT_KEY in ``wrapper`` instead of ``func`` as
            # ``wrapper`` is the one that users will get.
            compiled_obj = wrapper.__dict__.get(COMPILED_OBJECT_KEY, None)
            if compiled_obj is None:
                first_iter = True
                compiled_obj = _compile(func, module_override, *args, **kwargs)
                wrapper.__dict__[COMPILED_OBJECT_KEY] = compiled_obj

            with torch.no_grad():
                # N.B.: we don't need autograd as backward has already been
                # captured in the graph.
                output = compiled_obj.gm(
                    args,
                    kwargs,
                    compiled_obj.named_states,
                    compiled_obj.params_and_buffers,
                )[0]
            if first_iter and gm_transformation:
                # TODO: SPMD should provide a default and configurable
                # transformation.
                compiled_obj.gm = gm_transformation(compiled_obj.gm)
            return output

        return wrapper

    return inner


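# Example (minimal sketch, not part of the original file, assuming
# torch.distributed is initialized and that ``model``, ``opt``, and ``inp`` are
# a local nn.Module, its optimizer, and an input batch built by the caller):
#
#   @compile()
#   def train_step(model, opt, inp):
#       model(inp).sum().backward()
#       opt.step()
#
#   # The first call traces train_step, expands it with DTensor, and caches the
#   # resulting GraphModule; later calls reuse it.
#   train_step(model, opt, inp)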