Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-07 12:21:27 +01:00
[dtensor] support local_map as a decorator (#161353)

And extract it out as a convenience function for dynamo to wrap.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161353
Approved by: https://github.com/zpcore

parent 0e35805030
commit 15670f9075
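For orientation, here is a minimal usage sketch of the two spellings this change supports, modeled on the test diff below. The placement constants and the mesh/process-group setup are assumed to exist, and the function names are illustrative, not taken verbatim from the patch:

```python
import torch
import torch.distributed._functional_collectives as funcol
from torch.distributed.tensor import Replicate, Shard
from torch.distributed.tensor.experimental import local_map

# Illustrative 1-D mesh placements, as used in the touched test.
col_wise = [Shard(1)]
row_wise = [Shard(0)]
replicate = [Replicate()]


# New spelling: calling local_map with keyword arguments only returns a decorator.
@local_map(
    out_placements=replicate,
    in_placements=(None, col_wise, row_wise),
)
def mm_allreduce_forward(device_mesh, A, B):
    partial_sum_tensor = torch.mm(A, B)
    return funcol.all_reduce(partial_sum_tensor, "sum", device_mesh).wait()


# Existing spelling still works: pass the function as the first argument
# (previously written in the test as @partial(local_map, ...)).
def mm_allreduce(device_mesh, A, B):
    partial_sum_tensor = torch.mm(A, B)
    return funcol.all_reduce(partial_sum_tensor, "sum", device_mesh).wait()


mm_allreduce_dt = local_map(
    mm_allreduce,
    out_placements=replicate,
    in_placements=(None, col_wise, row_wise),
)
```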
```diff
@@ -1,6 +1,5 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates
 # Owner(s): ["oncall: distributed"]
-from functools import partial
 
 import torch
 import torch.distributed._functional_collectives as funcol
@@ -50,8 +49,7 @@ def mm_allreduce_forward(device_mesh, A, B):
     return funcol.all_reduce(partial_sum_tensor, "sum", device_mesh).wait()
 
 
-@partial(
-    local_map,
+@local_map(
     out_placements=replicate,
     in_placements=(None, col_wise, row_wise),
 )
```
```diff
@@ -24,10 +24,10 @@ OutputPlacements = Union[PlacementType, tuple[PlacementType, ...]]
 
 
 def local_map(
-    func: Callable,
-    out_placements: OutputPlacements,
-    in_placements: Optional[InputPlacements] = None,
-    in_grad_placements: Optional[InputPlacements] = None,
+    func: Optional[Callable] = None,
+    out_placements: OutputPlacements = None,
+    in_placements: InputPlacements = None,
+    in_grad_placements: InputPlacements = None,
     device_mesh: Optional[DeviceMesh] = None,
     *,
     redistribute_inputs: bool = False,
```
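Making `func` optional (with the other parameters defaulting accordingly) is what lets `local_map(...)` be called without a function and hand back a decorator. This is the common optional-argument decorator-factory pattern; a generic, self-contained sketch of the idea, not code from this patch:

```python
import functools


def local_map_like(func=None, *, setting="default"):
    """Usable both as @local_map_like(setting=...) and as local_map_like(f, setting=...)."""
    if func is None:
        # Decorator mode: no function yet, so return a decorator that re-enters
        # this factory with the function bound and the settings preserved.
        def decorator(f):
            return local_map_like(f, setting=setting)

        return decorator

    # Direct mode: build and return the wrapped callable.
    @functools.wraps(func)
    def wrapped(*args, **kwargs):
        print(f"calling {func.__name__} with setting={setting}")
        return func(*args, **kwargs)

    return wrapped


@local_map_like(setting="x")
def add(a, b):
    return a + b


double = local_map_like(lambda a: 2 * a, setting="y")
print(add(1, 2), double(3))
```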
```diff
@@ -133,114 +133,144 @@ def local_map(
     .. note:: This API is currently experimental and subject to change
     """
 
-    def wrapped(device_mesh: Optional[DeviceMesh], *args, **kwargs):
-        # process input args
-        flat_args, args_spec = pytree.tree_flatten(args)
-        if in_placements is not None:
-            assert len(in_placements) == len(flat_args), (
-                f"in_placements length {len(in_placements)} does not match the number "
-                f"of input args {len(flat_args)}!"
-            )
-
-        # we assume every DTensor object is placed on the same device mesh
-        flat_local_args = []
-        seen_dtensor_arg = False
-        for idx, arg in enumerate(flat_args):
-            if isinstance(arg, DTensor):
-                # TODO: the current code doesn't consider the uneven sharding case
-                # Need to think about what the consequence is when the input DTensor
-                # is uneven sharded.
-                if device_mesh is None:  # infer device mesh from the DTensor arg
-                    device_mesh = arg.device_mesh
-
-                # this function is applied to at least one DTensor argument
-                seen_dtensor_arg = True
-
-                if in_placements is not None:
-                    spec = in_placements[idx]
-                    assert spec is not None, (
-                        f"DTensor input {arg} expects placements but received {spec}!"
-                    )
-
-                    if not isinstance(spec, tuple):
-                        spec = tuple(spec)
-
-                    if arg.placements != spec:
-                        if redistribute_inputs:
-                            # redistribute to input placements
-                            arg = arg.redistribute(placements=spec)
-                        else:
-                            raise ValueError(
-                                f"arg {arg} in local_map has a mismatched placements: "
-                                f"arg placements is {arg.placements} but the input "
-                                f"placements is {spec}! "
-                                "If redistribute_inputs is wanted, set "
-                                "redistribute_inputs=True to local_map."
-                            )
-
-                if in_grad_placements is not None:
-                    spec = in_grad_placements[idx]
-                    assert spec is not None, (
-                        f"DTensor input {arg} expects in grad placements but received {spec}!"
-                    )
-                    if not isinstance(spec, tuple):
-                        spec = tuple(spec)
-                    local_arg = arg.to_local(grad_placements=spec)
-                else:
-                    local_arg = arg.to_local()
-
-                if isinstance(local_arg, AsyncCollectiveTensor):
-                    local_arg = local_arg.wait()
-
-                flat_local_args.append(local_arg)
-            else:
-                # Non-Tensor input must have None in `in_placements`
-                if in_placements is not None and not isinstance(arg, torch.Tensor):
-                    spec = in_placements[idx]
-                    assert spec is None, (
-                        f"Non-Tensor input {arg} expects None placements "
-                        f"but received {spec}!"
-                    )
-
-                flat_local_args.append(arg)
-
-        local_args = pytree.tree_unflatten(flat_local_args, args_spec)
-
-        out = func(*local_args, **kwargs)
-
-        if seen_dtensor_arg:
-            # process output to be DTensor if we've seen DTensor inputs
-            flat_out, out_spec = pytree.tree_flatten(out)
-
-            flat_dist_out = []
-            out_placements_tuple = (
-                out_placements
-                if isinstance(out_placements, tuple)
-                else (out_placements,)
-            )
-            assert len(flat_out) == len(out_placements_tuple), (
-                "local_map requires one PlacementType be provided for each output value,"
-                f" received {len(out_placements_tuple)} out_placements but"
-                f" {len(flat_out)} is expected!"
-            )
-            for out, spec in zip(flat_out, out_placements_tuple):
-                if isinstance(out, torch.Tensor):
-                    assert not isinstance(out, DTensor), (
-                        f"torch.Tensor output expected but received {type(out)}: {out}"
-                    )
-
-                    flat_dist_out.append(
-                        DTensor.from_local(out, device_mesh, spec, run_check=False)
-                    )
-                else:
-                    assert spec is None, (
-                        f"Non-tensor output {out} expects None placements but received {spec}!"
-                    )
-
-                    flat_dist_out.append(out)
-
-            return pytree.tree_unflatten(flat_dist_out, out_spec)
-        else:
-            return out
-
-    return functools.partial(wrapped, device_mesh)
+    if func is None:
+        # decorator mode
+        def decorated(func):
+            return local_map(
+                func=func,
+                out_placements=out_placements,
+                in_placements=in_placements,
+                in_grad_placements=in_grad_placements,
+                device_mesh=device_mesh,
+                redistribute_inputs=redistribute_inputs,
+            )
+
+        return decorated
+
+    return functools.partial(
+        _local_map_wrapped,
+        func,
+        out_placements,
+        in_placements,
+        in_grad_placements,
+        device_mesh,
+        redistribute_inputs,
+    )
+
+
+def _local_map_wrapped(
+    func: Callable,
+    out_placements: OutputPlacements,
+    in_placements: InputPlacements,
+    in_grad_placements: InputPlacements,
+    device_mesh: Optional[DeviceMesh],
+    redistribute_inputs: bool,
+    *args,
+    **kwargs,
+):
+    # process input args
+    flat_args, args_spec = pytree.tree_flatten(args)
+    if in_placements is not None:
+        assert len(in_placements) == len(flat_args), (
+            f"in_placements length {len(in_placements)} does not match the number "
+            f"of input args {len(flat_args)}!"
+        )
+
+    # we assume every DTensor object is placed on the same device mesh
+    flat_local_args = []
+    seen_dtensor_arg = False
+    for idx, arg in enumerate(flat_args):
+        if isinstance(arg, DTensor):
+            # TODO: the current code doesn't consider the uneven sharding case
+            # Need to think about what the consequence is when the input DTensor
+            # is uneven sharded.
+            if device_mesh is None:  # infer device mesh from the DTensor arg
+                device_mesh = arg.device_mesh
+
+            # this function is applied to at least one DTensor argument
+            seen_dtensor_arg = True
+
+            if in_placements is not None:
+                spec = in_placements[idx]
+                assert spec is not None, (
+                    f"DTensor input {arg} expects placements but received {spec}!"
+                )
+
+                if not isinstance(spec, tuple):
+                    spec = tuple(spec)
+
+                if arg.placements != spec:
+                    if redistribute_inputs:
+                        # redistribute to input placements
+                        arg = arg.redistribute(placements=spec)
+                    else:
+                        raise ValueError(
+                            f"arg {arg} in local_map has a mismatched placements: "
+                            f"arg placements is {arg.placements} but the input "
+                            f"placements is {spec}! "
+                            "If redistribute_inputs is wanted, set "
+                            "redistribute_inputs=True to local_map."
+                        )
+
+            if in_grad_placements is not None:
+                spec = in_grad_placements[idx]
+                assert spec is not None, (
+                    f"DTensor input {arg} expects in grad placements but received {spec}!"
+                )
+                if not isinstance(spec, tuple):
+                    spec = tuple(spec)
+                local_arg = arg.to_local(grad_placements=spec)
+            else:
+                local_arg = arg.to_local()
+
+            if isinstance(local_arg, AsyncCollectiveTensor):
+                local_arg = local_arg.wait()
+
+            flat_local_args.append(local_arg)
+        else:
+            # Non-Tensor input must have None in `in_placements`
+            if in_placements is not None and not isinstance(arg, torch.Tensor):
+                spec = in_placements[idx]
+                assert spec is None, (
+                    f"Non-Tensor input {arg} expects None placements "
+                    f"but received {spec}!"
+                )
+
+            flat_local_args.append(arg)
+
+    local_args = pytree.tree_unflatten(flat_local_args, args_spec)
+
+    out = func(*local_args, **kwargs)
+
+    if seen_dtensor_arg:
+        # process output to be DTensor if we've seen DTensor inputs
+        flat_out, out_spec = pytree.tree_flatten(out)
+
+        flat_dist_out = []
+        out_placements_tuple = (
+            out_placements if isinstance(out_placements, tuple) else (out_placements,)
+        )
+        assert len(flat_out) == len(out_placements_tuple), (
+            "local_map requires one PlacementType be provided for each output value,"
+            f" received {len(out_placements_tuple)} out_placements but"
+            f" {len(flat_out)} is expected!"
+        )
+        for out, spec in zip(flat_out, out_placements_tuple):
+            if isinstance(out, torch.Tensor):
+                assert not isinstance(out, DTensor), (
+                    f"torch.Tensor output expected but received {type(out)}: {out}"
+                )
+
+                flat_dist_out.append(
+                    DTensor.from_local(out, device_mesh, spec, run_check=False)
+                )
+            else:
+                assert spec is None, (
+                    f"Non-tensor output {out} expects None placements but received {spec}!"
+                )
+
+                flat_dist_out.append(out)
+
+        return pytree.tree_unflatten(flat_dist_out, out_spec)
+    else:
+        return out
```
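On the commit message's second point: the wrapper body no longer lives in a closure created inside `local_map`; it is extracted into the module-level `_local_map_wrapped`, which `local_map` returns bound via `functools.partial`. A `functools.partial` object exposes what it wraps through `.func`, `.args`, and `.keywords`, which is the property that makes such a callable convenient for an outer tool (for example dynamo) to recognize and wrap. A generic sketch of that property, with hypothetical names and no dynamo specifics:

```python
import functools


def _impl(scale, x):
    # Module-level implementation: one stable, importable function object that
    # external tooling can identify and special-case.
    return scale * x


def make_scaler(scale):
    # The public entry point only binds configuration onto the shared implementation.
    return functools.partial(_impl, scale)


triple = make_scaler(3)

# Unlike a fresh per-call closure, the returned partial advertises what it wraps:
assert triple.func is _impl
assert triple.args == (3,)
print(triple(4))  # 12
```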