[dtensor] support local_map as a decorator (#161353)

Also extracts the wrapper body into a standalone convenience function, `_local_map_wrapped`, for dynamo to wrap.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161353
Approved by: https://github.com/zpcore
Simon Fan 2025-08-26 22:16:04 -07:00 committed by PyTorch MergeBot
parent 0e35805030
commit 15670f9075
2 changed files with 132 additions and 104 deletions
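
Per the description, the per-call DTensor handling is pulled out of a nested closure so dynamo has a standalone, module-level function to wrap. A minimal sketch of what the non-decorator call path now returns; the function and placement names (`mm_allreduce_forward`, `replicate`, `col_wise`, `row_wise`) come from the updated test, and the comment reflects the new code shown below:

    wrapped_fn = local_map(
        mm_allreduce_forward,
        out_placements=replicate,
        in_placements=(None, col_wise, row_wise),
    )
    # wrapped_fn is functools.partial(_local_map_wrapped, mm_allreduce_forward,
    # replicate, (None, col_wise, row_wise), None, None, False) -- a partial over
    # the module-level _local_map_wrapped helper rather than a closure over a
    # nested `wrapped` function as before.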


@@ -1,6 +1,5 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates
 # Owner(s): ["oncall: distributed"]
-from functools import partial
 
 import torch
 import torch.distributed._functional_collectives as funcol
@@ -50,8 +49,7 @@ def mm_allreduce_forward(device_mesh, A, B):
     return funcol.all_reduce(partial_sum_tensor, "sum", device_mesh).wait()
 
 
-@partial(
-    local_map,
+@local_map(
     out_placements=replicate,
     in_placements=(None, col_wise, row_wise),
 )
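
For clarity, a rough desugaring of the decorator change above, using the same placement aliases from the test; `fn` stands for whichever function the decorator is applied to (its `def` line falls outside the hunk):

    decorated = local_map(
        out_placements=replicate,
        in_placements=(None, col_wise, row_wise),
    )  # func is None -> decorator mode, returns `decorated`
    fn = decorated(fn)
    # equivalent to the removed spelling:
    # fn = partial(local_map, out_placements=replicate,
    #              in_placements=(None, col_wise, row_wise))(fn)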


@@ -24,10 +24,10 @@ OutputPlacements = Union[PlacementType, tuple[PlacementType, ...]]
 
 
 def local_map(
-    func: Callable,
-    out_placements: OutputPlacements,
-    in_placements: Optional[InputPlacements] = None,
-    in_grad_placements: Optional[InputPlacements] = None,
+    func: Optional[Callable] = None,
+    out_placements: OutputPlacements = None,
+    in_placements: InputPlacements = None,
+    in_grad_placements: InputPlacements = None,
     device_mesh: Optional[DeviceMesh] = None,
     *,
     redistribute_inputs: bool = False,
@@ -133,114 +133,144 @@ def local_map(
     .. note:: This API is currently experimental and subject to change
     """
 
-    def wrapped(device_mesh: Optional[DeviceMesh], *args, **kwargs):
-        # process input args
-        flat_args, args_spec = pytree.tree_flatten(args)
-        if in_placements is not None:
-            assert len(in_placements) == len(flat_args), (
-                f"in_placements length {len(in_placements)} does not match the number "
-                f"of input args {len(flat_args)}!"
-            )
-
-        # we assume every DTensor object is placed on the same device mesh
-        flat_local_args = []
-        seen_dtensor_arg = False
-        for idx, arg in enumerate(flat_args):
-            if isinstance(arg, DTensor):
-                # TODO: the current code doesn't consider the uneven sharding case
-                # Need to think about what the consequence is when the input DTensor
-                # is uneven sharded.
-                if device_mesh is None:  # infer device mesh from the DTensor arg
-                    device_mesh = arg.device_mesh
-
-                # this function is applied to at least one DTensor argument
-                seen_dtensor_arg = True
-
-                if in_placements is not None:
-                    spec = in_placements[idx]
-                    assert spec is not None, (
-                        f"DTensor input {arg} expects placements but received {spec}!"
-                    )
-
-                    if not isinstance(spec, tuple):
-                        spec = tuple(spec)
-
-                    if arg.placements != spec:
-                        if redistribute_inputs:
-                            # redistribute to input placements
-                            arg = arg.redistribute(placements=spec)
-                        else:
-                            raise ValueError(
-                                f"arg {arg} in local_map has a mismatched placements: "
-                                f"arg placements is {arg.placements} but the input "
-                                f"placements is {spec}! "
-                                "If redistribute_inputs is wanted, set "
-                                "redistribute_inputs=True to local_map."
-                            )
-
-                if in_grad_placements is not None:
-                    spec = in_grad_placements[idx]
-                    assert spec is not None, (
-                        f"DTensor input {arg} expects in grad placements but received {spec}!"
-                    )
-                    if not isinstance(spec, tuple):
-                        spec = tuple(spec)
-                    local_arg = arg.to_local(grad_placements=spec)
-                else:
-                    local_arg = arg.to_local()
-
-                if isinstance(local_arg, AsyncCollectiveTensor):
-                    local_arg = local_arg.wait()
-
-                flat_local_args.append(local_arg)
-            else:
-                # Non-Tensor input must have None in `in_placements`
-                if in_placements is not None and not isinstance(arg, torch.Tensor):
-                    spec = in_placements[idx]
-                    assert spec is None, (
-                        f"Non-Tensor input {arg} expects None placements "
-                        f"but received {spec}!"
-                    )
-
-                flat_local_args.append(arg)
-
-        local_args = pytree.tree_unflatten(flat_local_args, args_spec)
-
-        out = func(*local_args, **kwargs)
-
-        if seen_dtensor_arg:
-            # process output to be DTensor if we've seen DTensor inputs
-            flat_out, out_spec = pytree.tree_flatten(out)
-
-            flat_dist_out = []
-            out_placements_tuple = (
-                out_placements
-                if isinstance(out_placements, tuple)
-                else (out_placements,)
-            )
-            assert len(flat_out) == len(out_placements_tuple), (
-                "local_map requires one PlacementType be provided for each output value,"
-                f" received {len(out_placements_tuple)} out_placements but"
-                f" {len(flat_out)} is expected!"
-            )
-            for out, spec in zip(flat_out, out_placements_tuple):
-                if isinstance(out, torch.Tensor):
-                    assert not isinstance(out, DTensor), (
-                        f"torch.Tensor output expected but received {type(out)}: {out}"
-                    )
-
-                    flat_dist_out.append(
-                        DTensor.from_local(out, device_mesh, spec, run_check=False)
-                    )
-                else:
-                    assert spec is None, (
-                        f"Non-tensor output {out} expects None placements but received {spec}!"
-                    )
-
-                    flat_dist_out.append(out)
-
-            return pytree.tree_unflatten(flat_dist_out, out_spec)
-        else:
-            return out
-
-    return functools.partial(wrapped, device_mesh)
+    if func is None:
+        # decorator mode
+        def decorated(func):
+            return local_map(
+                func=func,
+                out_placements=out_placements,
+                in_placements=in_placements,
+                in_grad_placements=in_grad_placements,
+                device_mesh=device_mesh,
+                redistribute_inputs=redistribute_inputs,
+            )
+
+        return decorated
+
+    return functools.partial(
+        _local_map_wrapped,
+        func,
+        out_placements,
+        in_placements,
+        in_grad_placements,
+        device_mesh,
+        redistribute_inputs,
+    )
+
+
+def _local_map_wrapped(
+    func: Callable,
+    out_placements: OutputPlacements,
+    in_placements: InputPlacements,
+    in_grad_placements: InputPlacements,
+    device_mesh: Optional[DeviceMesh],
+    redistribute_inputs: bool,
+    *args,
+    **kwargs,
+):
+    # process input args
+    flat_args, args_spec = pytree.tree_flatten(args)
+    if in_placements is not None:
+        assert len(in_placements) == len(flat_args), (
+            f"in_placements length {len(in_placements)} does not match the number "
+            f"of input args {len(flat_args)}!"
+        )
+
+    # we assume every DTensor object is placed on the same device mesh
+    flat_local_args = []
+    seen_dtensor_arg = False
+    for idx, arg in enumerate(flat_args):
+        if isinstance(arg, DTensor):
+            # TODO: the current code doesn't consider the uneven sharding case
+            # Need to think about what the consequence is when the input DTensor
+            # is uneven sharded.
+            if device_mesh is None:  # infer device mesh from the DTensor arg
+                device_mesh = arg.device_mesh
+
+            # this function is applied to at least one DTensor argument
+            seen_dtensor_arg = True
+
+            if in_placements is not None:
+                spec = in_placements[idx]
+                assert spec is not None, (
+                    f"DTensor input {arg} expects placements but received {spec}!"
+                )
+
+                if not isinstance(spec, tuple):
+                    spec = tuple(spec)
+
+                if arg.placements != spec:
+                    if redistribute_inputs:
+                        # redistribute to input placements
+                        arg = arg.redistribute(placements=spec)
+                    else:
+                        raise ValueError(
+                            f"arg {arg} in local_map has a mismatched placements: "
+                            f"arg placements is {arg.placements} but the input "
+                            f"placements is {spec}! "
+                            "If redistribute_inputs is wanted, set "
+                            "redistribute_inputs=True to local_map."
+                        )
+
+            if in_grad_placements is not None:
+                spec = in_grad_placements[idx]
+                assert spec is not None, (
+                    f"DTensor input {arg} expects in grad placements but received {spec}!"
+                )
+                if not isinstance(spec, tuple):
+                    spec = tuple(spec)
+                local_arg = arg.to_local(grad_placements=spec)
+            else:
+                local_arg = arg.to_local()
+
+            if isinstance(local_arg, AsyncCollectiveTensor):
+                local_arg = local_arg.wait()
+
+            flat_local_args.append(local_arg)
+        else:
+            # Non-Tensor input must have None in `in_placements`
+            if in_placements is not None and not isinstance(arg, torch.Tensor):
+                spec = in_placements[idx]
+                assert spec is None, (
+                    f"Non-Tensor input {arg} expects None placements "
+                    f"but received {spec}!"
+                )
+
+            flat_local_args.append(arg)
+
+    local_args = pytree.tree_unflatten(flat_local_args, args_spec)
+
+    out = func(*local_args, **kwargs)
+
+    if seen_dtensor_arg:
+        # process output to be DTensor if we've seen DTensor inputs
+        flat_out, out_spec = pytree.tree_flatten(out)
+
+        flat_dist_out = []
+        out_placements_tuple = (
+            out_placements if isinstance(out_placements, tuple) else (out_placements,)
+        )
+        assert len(flat_out) == len(out_placements_tuple), (
+            "local_map requires one PlacementType be provided for each output value,"
+            f" received {len(out_placements_tuple)} out_placements but"
+            f" {len(flat_out)} is expected!"
+        )
+        for out, spec in zip(flat_out, out_placements_tuple):
+            if isinstance(out, torch.Tensor):
+                assert not isinstance(out, DTensor), (
+                    f"torch.Tensor output expected but received {type(out)}: {out}"
+                )
+
+                flat_dist_out.append(
+                    DTensor.from_local(out, device_mesh, spec, run_check=False)
+                )
+            else:
+                assert spec is None, (
+                    f"Non-tensor output {out} expects None placements but received {spec}!"
+                )
+
+                flat_dist_out.append(out)
+
+        return pytree.tree_unflatten(flat_dist_out, out_spec)
+    else:
+        return out
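
End-to-end, the new decorator form can be used like the sketch below. This is illustrative only: it assumes a 2-rank job launched with torchrun, and the mesh, placements, and function names are not part of this diff.

    import torch
    from torch.distributed.device_mesh import init_device_mesh
    from torch.distributed.tensor import Replicate, Shard, distribute_tensor
    from torch.distributed.tensor.experimental import local_map

    # assumed setup: torchrun --nproc-per-node=2; "cpu" keeps the sketch runnable
    # without GPUs
    mesh = init_device_mesh("cpu", (2,))
    torch.manual_seed(0)  # same full tensors on every rank before sharding

    @local_map(
        out_placements=[Shard(0)],                  # output stays row-wise sharded
        in_placements=([Shard(0)], [Replicate()]),  # one entry per positional arg
    )
    def sharded_mm(x, w):
        # the body sees plain torch.Tensor shards, one per rank
        return torch.mm(x, w)

    x = distribute_tensor(torch.randn(8, 4), mesh, [Shard(0)])     # row-wise sharded
    w = distribute_tensor(torch.randn(4, 4), mesh, [Replicate()])  # replicated
    y = sharded_mm(x, w)  # a DTensor placed as [Shard(0)] on `mesh`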