[dtensor] support local_map as a decorator (#161353)

Also extract the wrapped implementation into a module-level convenience function (_local_map_wrapped) for dynamo to wrap.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161353
Approved by: https://github.com/zpcore
Authored by Simon Fan on 2025-08-26 22:16:04 -07:00, committed by PyTorch MergeBot
parent 0e35805030
commit 15670f9075
2 changed files with 132 additions and 104 deletions
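
In user code the change looks like the following. This is a minimal sketch: replicate, col_wise, and row_wise are illustrative stand-ins for the placements defined in the test, and mm_forward_old / mm_forward_new are placeholder names.

    from functools import partial

    from torch.distributed.tensor import Replicate, Shard
    from torch.distributed.tensor.experimental import local_map

    replicate = [Replicate()]  # output replicated across the mesh
    col_wise = [Shard(1)]      # input sharded along columns (dim 1)
    row_wise = [Shard(0)]      # input sharded along rows (dim 0)

    # Before: decorator-style use had to go through functools.partial,
    # because local_map required the function as its first argument.
    @partial(
        local_map,
        out_placements=replicate,
        in_placements=(None, col_wise, row_wise),
    )
    def mm_forward_old(device_mesh, A, B): ...

    # After: omitting func makes local_map return a decorator, so it can
    # be applied directly.
    @local_map(
        out_placements=replicate,
        in_placements=(None, col_wise, row_wise),
    )
    def mm_forward_new(device_mesh, A, B): ...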

File: test/distributed/tensor/experimental/test_local_map.py

@@ -1,6 +1,5 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates
 # Owner(s): ["oncall: distributed"]
-from functools import partial
 
 import torch
 import torch.distributed._functional_collectives as funcol
@@ -50,8 +49,7 @@ def mm_allreduce_forward(device_mesh, A, B):
     return funcol.all_reduce(partial_sum_tensor, "sum", device_mesh).wait()
 
 
-@partial(
-    local_map,
+@local_map(
     out_placements=replicate,
     in_placements=(None, col_wise, row_wise),
 )
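
For context, here is a hedged end-to-end sketch of driving the decorated test function with DTensor inputs. The mesh size, tensor shapes, and placement definitions are illustrative assumptions; it would need to run under torchrun with two ranks.

    import torch
    import torch.distributed._functional_collectives as funcol
    from torch.distributed.device_mesh import init_device_mesh
    from torch.distributed.tensor import Replicate, Shard, distribute_tensor
    from torch.distributed.tensor.experimental import local_map

    replicate = [Replicate()]
    col_wise = [Shard(1)]
    row_wise = [Shard(0)]

    @local_map(
        out_placements=replicate,
        in_placements=(None, col_wise, row_wise),
    )
    def mm_allreduce_forward(device_mesh, A, B):
        # Inside the wrapped function, A and B are plain local torch.Tensor
        # shards; the local matmul produces a partial sum that is then
        # reduced across the mesh.
        partial_sum_tensor = torch.mm(A, B)
        return funcol.all_reduce(partial_sum_tensor, "sum", device_mesh).wait()

    mesh = init_device_mesh("cuda", (2,))
    A = distribute_tensor(torch.randn(8, 8), mesh, col_wise)  # DTensor sharded on dim 1
    B = distribute_tensor(torch.randn(8, 8), mesh, row_wise)  # DTensor sharded on dim 0
    out = mm_allreduce_forward(mesh, A, B)                    # replicated DTensor result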

File: torch/distributed/tensor/experimental/_func_map.py

@@ -24,10 +24,10 @@ OutputPlacements = Union[PlacementType, tuple[PlacementType, ...]]
 
 
 def local_map(
-    func: Callable,
-    out_placements: OutputPlacements,
-    in_placements: Optional[InputPlacements] = None,
-    in_grad_placements: Optional[InputPlacements] = None,
+    func: Optional[Callable] = None,
+    out_placements: OutputPlacements = None,
+    in_placements: InputPlacements = None,
+    in_grad_placements: InputPlacements = None,
     device_mesh: Optional[DeviceMesh] = None,
     *,
     redistribute_inputs: bool = False,
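
Because func now defaults to None, the two call forms below should be equivalent. This is an illustrative sketch; local_add and the placements are not taken from the PR.

    from torch.distributed.tensor import Shard
    from torch.distributed.tensor.experimental import local_map

    def local_add(x, y):
        return x + y

    # Direct call, as before: the function is passed as the first argument.
    add_a = local_map(local_add, out_placements=[Shard(0)], in_placements=([Shard(0)], [Shard(0)]))

    # Decorator-factory form: with func omitted, local_map returns a
    # decorator, which is then applied to the function.
    add_b = local_map(out_placements=[Shard(0)], in_placements=([Shard(0)], [Shard(0)]))(local_add)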
@@ -133,7 +133,41 @@ def local_map(
 
     .. note:: This API is currently experimental and subject to change
     """
-    def wrapped(device_mesh: Optional[DeviceMesh], *args, **kwargs):
+    if func is None:
+        # decorator mode
+        def decorated(func):
+            return local_map(
+                func=func,
+                out_placements=out_placements,
+                in_placements=in_placements,
+                in_grad_placements=in_grad_placements,
+                device_mesh=device_mesh,
+                redistribute_inputs=redistribute_inputs,
+            )
+
+        return decorated
+
+    return functools.partial(
+        _local_map_wrapped,
+        func,
+        out_placements,
+        in_placements,
+        in_grad_placements,
+        device_mesh,
+        redistribute_inputs,
+    )
+
+
+def _local_map_wrapped(
+    func: Callable,
+    out_placements: OutputPlacements,
+    in_placements: InputPlacements,
+    in_grad_placements: InputPlacements,
+    device_mesh: Optional[DeviceMesh],
+    redistribute_inputs: bool,
+    *args,
+    **kwargs,
+):
     # process input args
     flat_args, args_spec = pytree.tree_flatten(args)
     if in_placements is not None:
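
The branch above is the standard pattern for a decorator that also supports direct calls, with the work delegated to a module-level helper. Below is a minimal standalone sketch of the same idea; scaled and _scaled_call are invented names for illustration and are not part of the PR.

    import functools
    from typing import Callable, Optional

    def _scaled_call(func: Callable, scale: float, *args, **kwargs):
        # Module-level worker: all configuration arrives as explicit leading
        # arguments, which makes the callable easy for tracers to handle.
        return scale * func(*args, **kwargs)

    def scaled(func: Optional[Callable] = None, scale: float = 1.0):
        if func is None:
            # Called as @scaled(scale=2.0): return a decorator that re-enters
            # scaled() with the function filled in.
            def decorator(f):
                return scaled(func=f, scale=scale)
            return decorator
        # Called as scaled(f, scale=2.0), or re-entered from the decorator
        # branch: bind the configuration via functools.partial.
        return functools.partial(_scaled_call, func, scale)

    @scaled(scale=2.0)
    def add(a, b):
        return a + b

    assert add(1, 2) == 6  # (1 + 2) * 2.0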
@@ -214,9 +248,7 @@ def local_map(
         flat_dist_out = []
         out_placements_tuple = (
-            out_placements
-            if isinstance(out_placements, tuple)
-            else (out_placements,)
+            out_placements if isinstance(out_placements, tuple) else (out_placements,)
         )
         assert len(flat_out) == len(out_placements_tuple), (
             "local_map requires one PlacementType be provided for each output value,"
@@ -242,5 +274,3 @@ def local_map(
         return pytree.tree_unflatten(flat_dist_out, out_spec)
     else:
         return out
-
-    return functools.partial(wrapped, device_mesh)
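
With this change, the callable returned by local_map is a functools.partial bound to the module-level _local_map_wrapped rather than to the nested wrapped closure. A sketch under that assumption; local_mul and the placements are illustrative.

    import functools

    from torch.distributed.tensor import Shard
    from torch.distributed.tensor.experimental import local_map

    def local_mul(x, y):
        return x * y

    sharded_mul = local_map(
        local_mul,
        out_placements=[Shard(0)],
        in_placements=([Shard(0)], [Shard(0)]),
    )

    # The result is a plain functools.partial whose bound arguments carry the
    # original function plus the placement/redistribution configuration,
    # giving dynamo a single module-level target to wrap.
    assert isinstance(sharded_mul, functools.partial)
    print(sharded_mul.func)  # expected: _local_map_wrapped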