[dtensor] support local_map as a decorator (#161353)
And extract it out as a convenience function for dynamo to wrap.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161353
Approved by: https://github.com/zpcore
This commit is contained in:
parent 0e35805030
commit 15670f9075
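For orientation before the diff: with this change, local_map can be applied directly as a decorator instead of being wrapped in functools.partial. A hedged sketch of the new spelling, assembled from the test hunks below (the placement specs replicate, col_wise, row_wise and the torch.mm line in the body come from the surrounding test, not from this diff):

import torch
import torch.distributed._functional_collectives as funcol
from torch.distributed.tensor.experimental import local_map


@local_map(
    out_placements=replicate,
    in_placements=(None, col_wise, row_wise),
)
def mm_allreduce_forward(device_mesh, A, B):
    partial_sum_tensor = torch.mm(A, B)
    return funcol.all_reduce(partial_sum_tensor, "sum", device_mesh).wait()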
@@ -1,6 +1,5 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates
 # Owner(s): ["oncall: distributed"]
-from functools import partial
 
 import torch
 import torch.distributed._functional_collectives as funcol
@@ -50,8 +49,7 @@ def mm_allreduce_forward(device_mesh, A, B):
     return funcol.all_reduce(partial_sum_tensor, "sum", device_mesh).wait()
 
 
-@partial(
-    local_map,
+@local_map(
     out_placements=replicate,
     in_placements=(None, col_wise, row_wise),
 )
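A hedged sketch of how the decorated function is then exercised with DTensors (the mesh setup, tensor shapes, and the concrete Shard/Replicate specs are assumptions matched to the placement names above, not lines from this diff): the wrapper passes each DTensor argument's local shard into the body and wraps the plain-tensor result back into a DTensor according to out_placements. The hunks that follow are against the local_map implementation itself.

import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Replicate, Shard, distribute_tensor

# Assumed placement specs matching the names used in the test above.
replicate = [Replicate()]
col_wise = [Shard(1)]
row_wise = [Shard(0)]

mesh = init_device_mesh("cuda", (4,))                        # needs an initialized process group
A = distribute_tensor(torch.randn(16, 32), mesh, col_wise)   # sharded on the inner dim
B = distribute_tensor(torch.randn(32, 16), mesh, row_wise)   # sharded on the inner dim
Y = mm_allreduce_forward(mesh, A, B)                         # replicated DTensor output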
@@ -24,10 +24,10 @@ OutputPlacements = Union[PlacementType, tuple[PlacementType, ...]]
 
 
 def local_map(
-    func: Callable,
-    out_placements: OutputPlacements,
-    in_placements: Optional[InputPlacements] = None,
-    in_grad_placements: Optional[InputPlacements] = None,
+    func: Optional[Callable] = None,
+    out_placements: OutputPlacements = None,
+    in_placements: InputPlacements = None,
+    in_grad_placements: InputPlacements = None,
     device_mesh: Optional[DeviceMesh] = None,
     *,
     redistribute_inputs: bool = False,
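The signature change above is what enables decorator mode: func now defaults to None and the placement arguments gain defaults, so local_map can be called with keyword arguments only. A small sketch of the two resulting call forms (using the same assumed names as above):

# Wrapping an existing function directly, as before:
wrapped = local_map(
    mm_allreduce_forward,
    out_placements=replicate,
    in_placements=(None, col_wise, row_wise),
)

# Calling without `func` now returns a decorator, so this is equivalent:
decorate = local_map(
    out_placements=replicate,
    in_placements=(None, col_wise, row_wise),
)
also_wrapped = decorate(mm_allreduce_forward)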
@@ -133,7 +133,41 @@ def local_map(
     .. note:: This API is currently experimental and subject to change
     """
 
-    def wrapped(device_mesh: Optional[DeviceMesh], *args, **kwargs):
+    if func is None:
+        # decorator mode
+        def decorated(func):
+            return local_map(
+                func=func,
+                out_placements=out_placements,
+                in_placements=in_placements,
+                in_grad_placements=in_grad_placements,
+                device_mesh=device_mesh,
+                redistribute_inputs=redistribute_inputs,
+            )
+
+        return decorated
+
+    return functools.partial(
+        _local_map_wrapped,
+        func,
+        out_placements,
+        in_placements,
+        in_grad_placements,
+        device_mesh,
+        redistribute_inputs,
+    )
+
+
+def _local_map_wrapped(
+    func: Callable,
+    out_placements: OutputPlacements,
+    in_placements: InputPlacements,
+    in_grad_placements: InputPlacements,
+    device_mesh: Optional[DeviceMesh],
+    redistribute_inputs: bool,
+    *args,
+    **kwargs,
+):
     # process input args
     flat_args, args_spec = pytree.tree_flatten(args)
     if in_placements is not None:
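A self-contained, pure-Python sketch of the pattern the hunk above introduces (the names my_map, _my_map_wrapped, and scale are illustrative, not part of the PyTorch API): when called without func, return a decorator that re-enters the same function; otherwise bind the configuration into a functools.partial over a module-level worker so the wrapper stays a plain function rather than a closure.

import functools
from typing import Callable, Optional


def my_map(func: Optional[Callable] = None, *, scale: int = 1):
    if func is None:
        # decorator mode: @my_map(scale=...) calls back into my_map with func filled in
        def decorated(f):
            return my_map(func=f, scale=scale)

        return decorated

    # direct mode: bind func and the configuration into a partial over the worker
    return functools.partial(_my_map_wrapped, func, scale)


def _my_map_wrapped(func: Callable, scale: int, *args, **kwargs):
    # the extracted worker: run the wrapped function, then apply the configuration
    return func(*args, **kwargs) * scale


def add(x, y):
    return x + y


tripled_add = my_map(add, scale=3)  # direct-call mode
assert tripled_add(1, 2) == 9


@my_map(scale=2)                    # decorator mode
def mul(x, y):
    return x * y


assert mul(3, 4) == 24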
@@ -214,9 +248,7 @@ def local_map(
 
     flat_dist_out = []
     out_placements_tuple = (
-        out_placements
-        if isinstance(out_placements, tuple)
-        else (out_placements,)
+        out_placements if isinstance(out_placements, tuple) else (out_placements,)
     )
     assert len(flat_out) == len(out_placements_tuple), (
         "local_map requires one PlacementType be provided for each output value,"
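The reflowed expression above normalizes out_placements so that a single placement spec and a tuple of per-output specs are handled uniformly, which the length assertion below it relies on. A pure-Python stand-in (no DTensor types) just to show the shape of the data:

out_placements = ["replicate"]  # stand-in for a single PlacementType
out_placements_tuple = (
    out_placements if isinstance(out_placements, tuple) else (out_placements,)
)
assert out_placements_tuple == (["replicate"],)  # always a tuple, one entry per flattened output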
@@ -242,5 +274,3 @@ def local_map(
         return pytree.tree_unflatten(flat_dist_out, out_spec)
     else:
         return out
-
-    return functools.partial(wrapped, device_mesh)
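With the old inner wrapped closure gone, local_map now returns a functools.partial over the module-level _local_map_wrapped, which is the "convenience function for dynamo to wrap" mentioned in the commit message. A hedged introspection sketch (placeholder function and placements as above; the assertions reflect the implementation detail shown in this diff and could change):

import functools

local_fn = local_map(
    mm_allreduce_forward,
    out_placements=replicate,
    in_placements=(None, col_wise, row_wise),
)
# The returned object is a plain functools.partial whose target is the
# extracted top-level function rather than a closure over local state,
# which is easier for torch.compile/dynamo to intercept.
assert isinstance(local_fn, functools.partial)
assert local_fn.func.__name__ == "_local_map_wrapped"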