From 15670f9075709eb0e6257d51dee0f79b51f7df1c Mon Sep 17 00:00:00 2001
From: Simon Fan
Date: Tue, 26 Aug 2025 22:16:04 -0700
Subject: [PATCH] [dtensor] support local_map as a decorator (#161353)

And extract it out as a convenience function for dynamo to wrap

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161353
Approved by: https://github.com/zpcore
---
 .../tensor/experimental/test_local_map.py |   4 +-
 .../tensor/experimental/_func_map.py      | 232 ++++++++++--------
 2 files changed, 132 insertions(+), 104 deletions(-)

diff --git a/test/distributed/tensor/experimental/test_local_map.py b/test/distributed/tensor/experimental/test_local_map.py
index 1e1b4fa8f27..dad23226363 100644
--- a/test/distributed/tensor/experimental/test_local_map.py
+++ b/test/distributed/tensor/experimental/test_local_map.py
@@ -1,6 +1,5 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates
 # Owner(s): ["oncall: distributed"]
-from functools import partial
 
 import torch
 import torch.distributed._functional_collectives as funcol
@@ -50,8 +49,7 @@ def mm_allreduce_forward(device_mesh, A, B):
     return funcol.all_reduce(partial_sum_tensor, "sum", device_mesh).wait()
 
 
-@partial(
-    local_map,
+@local_map(
     out_placements=replicate,
     in_placements=(None, col_wise, row_wise),
 )
diff --git a/torch/distributed/tensor/experimental/_func_map.py b/torch/distributed/tensor/experimental/_func_map.py
index fd91328c0b3..31cdd0f9a06 100644
--- a/torch/distributed/tensor/experimental/_func_map.py
+++ b/torch/distributed/tensor/experimental/_func_map.py
@@ -24,10 +24,10 @@ OutputPlacements = Union[PlacementType, tuple[PlacementType, ...]]
 
 
 def local_map(
-    func: Callable,
-    out_placements: OutputPlacements,
-    in_placements: Optional[InputPlacements] = None,
-    in_grad_placements: Optional[InputPlacements] = None,
+    func: Optional[Callable] = None,
+    out_placements: OutputPlacements = None,
+    in_placements: InputPlacements = None,
+    in_grad_placements: InputPlacements = None,
     device_mesh: Optional[DeviceMesh] = None,
     *,
     redistribute_inputs: bool = False,
@@ -133,114 +133,144 @@ def local_map(
     .. note:: This API is currently experimental and subject to change
     """
 
-    def wrapped(device_mesh: Optional[DeviceMesh], *args, **kwargs):
-        # process input args
-        flat_args, args_spec = pytree.tree_flatten(args)
-        if in_placements is not None:
-            assert len(in_placements) == len(flat_args), (
-                f"in_placements length {len(in_placements)} does not match the number "
-                f"of input args {len(flat_args)}!"
+    if func is None:
+        # decorator mode
+        def decorated(func):
+            return local_map(
+                func=func,
+                out_placements=out_placements,
+                in_placements=in_placements,
+                in_grad_placements=in_grad_placements,
+                device_mesh=device_mesh,
+                redistribute_inputs=redistribute_inputs,
             )
 
-        # we assume every DTensor object is placed on the same device mesh
-        flat_local_args = []
-        seen_dtensor_arg = False
-        for idx, arg in enumerate(flat_args):
-            if isinstance(arg, DTensor):
-                # TODO: the current code doesn't consider the uneven sharding case
-                # Need to think about what the consequence is when the input DTensor
-                # is uneven sharded.
-                if device_mesh is None:  # infer device mesh from the DTensor arg
-                    device_mesh = arg.device_mesh
+        return decorated
 
-                # this function is applied to at least one DTensor argument
-                seen_dtensor_arg = True
+    return functools.partial(
+        _local_map_wrapped,
+        func,
+        out_placements,
+        in_placements,
+        in_grad_placements,
+        device_mesh,
+        redistribute_inputs,
+    )
 
-                if in_placements is not None:
-                    spec = in_placements[idx]
-                    assert spec is not None, (
-                        f"DTensor input {arg} expects placements but received {spec}!"
-                    )
-
-                    if not isinstance(spec, tuple):
-                        spec = tuple(spec)
+
+def _local_map_wrapped(
+    func: Callable,
+    out_placements: OutputPlacements,
+    in_placements: InputPlacements,
+    in_grad_placements: InputPlacements,
+    device_mesh: Optional[DeviceMesh],
+    redistribute_inputs: bool,
+    *args,
+    **kwargs,
+):
+    # process input args
+    flat_args, args_spec = pytree.tree_flatten(args)
+    if in_placements is not None:
+        assert len(in_placements) == len(flat_args), (
+            f"in_placements length {len(in_placements)} does not match the number "
+            f"of input args {len(flat_args)}!"
+        )
 
-                    if arg.placements != spec:
-                        if redistribute_inputs:
-                            # redistribute to input placements
-                            arg = arg.redistribute(placements=spec)
-                        else:
-                            raise ValueError(
-                                f"arg {arg} in local_map has a mismatched placements: "
-                                f"arg placements is {arg.placements} but the input "
-                                f"placements is {spec}! "
-                                "If redistribute_inputs is wanted, set "
-                                "redistribute_inputs=True to local_map."
-                            )
+    # we assume every DTensor object is placed on the same device mesh
+    flat_local_args = []
+    seen_dtensor_arg = False
+    for idx, arg in enumerate(flat_args):
+        if isinstance(arg, DTensor):
+            # TODO: the current code doesn't consider the uneven sharding case
+            # Need to think about what the consequence is when the input DTensor
+            # is uneven sharded.
+            if device_mesh is None:  # infer device mesh from the DTensor arg
+                device_mesh = arg.device_mesh
 
-                if in_grad_placements is not None:
-                    spec = in_grad_placements[idx]
-                    assert spec is not None, (
-                        f"DTensor input {arg} expects in grad placements but received {spec}!"
-                    )
-                    if not isinstance(spec, tuple):
-                        spec = tuple(spec)
-                    local_arg = arg.to_local(grad_placements=spec)
-                else:
-                    local_arg = arg.to_local()
+            # this function is applied to at least one DTensor argument
+            seen_dtensor_arg = True
 
-                if isinstance(local_arg, AsyncCollectiveTensor):
-                    local_arg = local_arg.wait()
+            if in_placements is not None:
+                spec = in_placements[idx]
+                assert spec is not None, (
+                    f"DTensor input {arg} expects placements but received {spec}!"
+                )
 
-                flat_local_args.append(local_arg)
-            else:
-                # Non-Tensor input must have None in `in_placements`
-                if in_placements is not None and not isinstance(arg, torch.Tensor):
-                    spec = in_placements[idx]
-                    assert spec is None, (
-                        f"Non-Tensor input {arg} expects None placements "
-                        f"but received {spec}!"
-                    )
+                if not isinstance(spec, tuple):
+                    spec = tuple(spec)
 
-                flat_local_args.append(arg)
+                if arg.placements != spec:
+                    if redistribute_inputs:
+                        # redistribute to input placements
+                        arg = arg.redistribute(placements=spec)
+                    else:
+                        raise ValueError(
+                            f"arg {arg} in local_map has a mismatched placements: "
+                            f"arg placements is {arg.placements} but the input "
+                            f"placements is {spec}! "
+                            "If redistribute_inputs is wanted, set "
+                            "redistribute_inputs=True to local_map."
+                        )
 
-        local_args = pytree.tree_unflatten(flat_local_args, args_spec)
-
-        out = func(*local_args, **kwargs)
-
-        if seen_dtensor_arg:
-            # process output to be DTensor if we've seen DTensor inputs
-            flat_out, out_spec = pytree.tree_flatten(out)
-
-            flat_dist_out = []
-            out_placements_tuple = (
-                out_placements
-                if isinstance(out_placements, tuple)
-                else (out_placements,)
-            )
-            assert len(flat_out) == len(out_placements_tuple), (
-                "local_map requires one PlacementType be provided for each output value,"
-                f" received {len(out_placements_tuple)} out_placements but"
-                f" {len(flat_out)} is expected!"
-            )
-            for out, spec in zip(flat_out, out_placements_tuple):
-                if isinstance(out, torch.Tensor):
-                    assert not isinstance(out, DTensor), (
-                        f"torch.Tensor output expected but received {type(out)}: {out}"
-                    )
+            if in_grad_placements is not None:
+                spec = in_grad_placements[idx]
+                assert spec is not None, (
+                    f"DTensor input {arg} expects in grad placements but received {spec}!"
+                )
+                if not isinstance(spec, tuple):
+                    spec = tuple(spec)
+                local_arg = arg.to_local(grad_placements=spec)
+            else:
+                local_arg = arg.to_local()
+
+            if isinstance(local_arg, AsyncCollectiveTensor):
+                local_arg = local_arg.wait()
+
+            flat_local_args.append(local_arg)
+        else:
+            # Non-Tensor input must have None in `in_placements`
+            if in_placements is not None and not isinstance(arg, torch.Tensor):
+                spec = in_placements[idx]
+                assert spec is None, (
+                    f"Non-Tensor input {arg} expects None placements "
+                    f"but received {spec}!"
+                )
+
+            flat_local_args.append(arg)
 
-                    flat_dist_out.append(
-                        DTensor.from_local(out, device_mesh, spec, run_check=False)
-                    )
-                else:
-                    assert spec is None, (
-                        f"Non-tensor output {out} expects None placements but received {spec}!"
-                    )
+    local_args = pytree.tree_unflatten(flat_local_args, args_spec)
 
-                    flat_dist_out.append(out)
+    out = func(*local_args, **kwargs)
 
-            return pytree.tree_unflatten(flat_dist_out, out_spec)
-        else:
-            return out
+    if seen_dtensor_arg:
+        # process output to be DTensor if we've seen DTensor inputs
+        flat_out, out_spec = pytree.tree_flatten(out)
 
-    return functools.partial(wrapped, device_mesh)
+        flat_dist_out = []
+        out_placements_tuple = (
+            out_placements if isinstance(out_placements, tuple) else (out_placements,)
+        )
+        assert len(flat_out) == len(out_placements_tuple), (
+            "local_map requires one PlacementType be provided for each output value,"
+            f" received {len(out_placements_tuple)} out_placements but"
+            f" {len(flat_out)} is expected!"
+        )
+        for out, spec in zip(flat_out, out_placements_tuple):
+            if isinstance(out, torch.Tensor):
+                assert not isinstance(out, DTensor), (
+                    f"torch.Tensor output expected but received {type(out)}: {out}"
+                )
+
+                flat_dist_out.append(
+                    DTensor.from_local(out, device_mesh, spec, run_check=False)
+                )
+            else:
+                assert spec is None, (
+                    f"Non-tensor output {out} expects None placements but received {spec}!"
+                )
+
+                flat_dist_out.append(out)
+
+        return pytree.tree_unflatten(flat_dist_out, out_spec)
+    else:
+        return out