Apparently https://github.com/pytorch/pytorch/pull/78142 made torch.JIT accept simple generator expressions, which lets us enable rules that replace unnecessary list comprehensions with generators in any()/all(). This was originally part of #99280, but I split it off into this PR so that it can be easily reverted should anything break. Pull Request resolved: https://github.com/pytorch/pytorch/pull/99890 Approved by: https://github.com/justinchuby, https://github.com/kit1980, https://github.com/malfet
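For illustration, the kind of rewrite such a lint rule performs (a hypothetical before/after; `xs` is a made-up stand-in, not code from the file below):

    # before: materializes an intermediate list just to reduce over it
    all([x > 0 for x in xs])
    # after: a generator expression, evaluated lazily and now scriptable
    all(x > 0 for x in xs)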
import logging
import operator
from dataclasses import dataclass
from enum import auto, Enum
from functools import partial
from typing import (
    Any,
    Callable,
    cast,
    Dict,
    List,
    Optional,
    Sequence,
    Set,
    Tuple,
    Union,
)

import torch
import torch.distributed._spmd.experimental_ops
import torch.fx as fx
import torch.nn as nn
from torch._functorch.aot_autograd import aot_module, make_boxed_func
from torch._guards import detect_fake_mode
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.distributed._spmd.aot_function_patch import patched_aot_function
from torch.distributed._spmd.comm_tensor import _get_tracer
from torch.distributed._spmd.distributed_graph import DistributedGraph
from torch.distributed._spmd.graph_utils import OP
from torch.distributed._spmd.log_utils import get_logger

from torch.distributed._tensor import DeviceMesh, DTensor
from torch.distributed._tensor.dispatch import (
    _CURRENT_DECOMPOSITION_TABLE,
    _operator_dispatch,
)
from torch.distributed._tensor.op_schema import OpSchema
from torch.distributed._tensor.placement_types import (
    _Partial,
    Placement,
    Replicate,
    Shard,
)
from torch.distributed._tensor.redistribute import _redistribute_with_local_tensor
from torch.fx.experimental.proxy_tensor import (
    make_fx,
    maybe_disable_fake_tensor_mode,
    proxy_slot,
)
from torch.utils._pytree import tree_flatten, tree_map, tree_map_only, tree_unflatten

# patch aot_function so that we can pass the full (non-sharded) input to capture the graph
# pyre-fixme
torch._functorch.aot_autograd.aot_function = patched_aot_function  # type: ignore[assignment]

logger: Optional[logging.Logger] = None

aten = torch.ops.aten


class TrainingPhase(Enum):
    FORWARD = auto()
    BACKWARD = auto()


@dataclass
class Schema:
    mesh: DeviceMesh
    placements: List[Placement]


@dataclass
class DSymInt:
    """
    DSymInt represents a value retrieved by a SymInt op from a DTensor. DSymInt
    helps View and Factory ops to determine the placement and shape of the
    output tensor, as those operators either do not have an input DTensor or
    the input DTensor is insufficient to determine the output tensor's placement.
    """

    global_value: int  # value that the SymInt evaluates to
    local_value: int  # value that this SymInt evaluates to on the local shard
    mesh: DeviceMesh  # device mesh of the DTensor from which this SymInt is retrieved

    def is_shard(self) -> bool:
        return self.local_value != self.global_value

    @classmethod
    def from_node(cls, node: fx.Node, dtensor: DTensor) -> "DSymInt":
        dim: int = 0
        if node.target == aten.sym_size:
            dim = cast(int, node.args[1])
            return cls(
                global_value=dtensor.size(dim),
                local_value=dtensor.to_local().size(dim),
                mesh=dtensor.device_mesh,
            )
        elif node.target == aten.sym_numel:
            return cls(
                global_value=dtensor.numel(),
                local_value=dtensor.to_local().numel(),
                mesh=dtensor.device_mesh,
            )
        elif node.target == aten.sym_stride:
            dim = cast(int, node.args[1])
            return cls(
                global_value=dtensor.stride(dim),
                local_value=dtensor.to_local().stride(dim),
                mesh=dtensor.device_mesh,
            )
        else:
            raise NotImplementedError(f"DSymInt does not support {node.target}")
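
# Illustrative example of DSymInt (assumed values, not code from this file):
# on a 1-D DeviceMesh of 4 ranks, a DTensor with global shape [8, 16] and
# placements [Shard(0)] has local shape [2, 16]. A sym_size(t, 0) node then
# maps to DSymInt(global_value=8, local_value=2, mesh=mesh), whose is_shard()
# is True, while sym_size(t, 1) yields local_value == global_value == 16.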


def _is_partial_dtensor(obj: Any) -> bool:
    """Check whether the object is a DTensor with at least one _Partial placement."""
    if not isinstance(obj, DTensor):
        return False

    is_partial = False
    for placement in obj.placements:
        if isinstance(placement, _Partial):
            is_partial = True
            break

    return is_partial


def _dispatch_with_local_tensors(
    op: torch._ops.OpOverload,
    local_args: Tuple[Any, ...],
    kwargs: Optional[Dict[str, Any]] = None,
    specs: Optional[
        Dict[
            torch.Tensor,
            Tuple[torch.Size, DeviceMesh, Sequence[Placement], Sequence[Placement]],
        ]
    ] = None,
) -> Any:
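    """
    Dispatch ``op`` with local tensors, first redistributing any tensor
    argument that appears in ``specs`` (a map from a local tensor to its
    global size, device mesh, current placements, and target placements)
    via ``_redistribute_with_local_tensor``.
    """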
    if kwargs is None:
        kwargs = {}
    if specs is None:
        specs = {}

    def redistribute(arg: Any) -> Any:
        return (
            _redistribute_with_local_tensor(arg, *specs[arg])  # type: ignore[index]
            if isinstance(arg, torch.Tensor) and arg in specs  # type: ignore[operator]
            else arg
        )

    # TODO: this is broken because it won't redistribute tensors that appear in the kwargs
    return op(*tree_map(redistribute, local_args), **kwargs)


# Figure out how to specify a type spec for the return specs value
# without the entire structure.
# pyre-fixme
def _update_specs_for_redistribute(args, target_schema, redistribute):
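    """
    For every DTensor in ``args`` (when ``redistribute`` is set), record its
    (global size, target mesh, current placements, target placements) keyed
    by its local tensor, and substitute the local tensor into the flattened
    args schema. Returns the specs map and the re-unflattened args.
    """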
    # Code adapted from pack_args_kwargs_with_local_tensor
    flatten_args, args_tree_spec = tree_flatten(args)
    flatten_args_schema, _ = tree_flatten(target_schema.args_schema)

    specs: Dict[
        torch.Tensor,
        Tuple[
            torch.Size,
            DeviceMesh,
            Sequence[Placement],
            Sequence[Placement],
        ],
    ] = {}
    for i, arg in enumerate(flatten_args):
        if isinstance(arg, DTensor):
            if redistribute:
                specs[arg._local_tensor] = (
                    arg.size(),
                    flatten_args_schema[i].mesh,
                    arg.placements,
                    flatten_args_schema[i].placements,
                )
            flatten_args_schema[i] = arg._local_tensor

    unflattened_args = tree_unflatten(flatten_args_schema, args_tree_spec)
    return specs, unflattened_args


# When no tensor redistribution is required, we only need to update non-tensor args
# of the node according to op_schema and avoid building a GraphModule just for the
# node.
def _update_node_from_op_schema(node: torch.fx.Node, op_schema: OpSchema) -> None:
    flat_args, args_tree_spec = tree_flatten(node.args)
    flat_args_schema, _ = tree_flatten(op_schema.args_schema)

    def is_sym_int_or_int(arg: Union[int, torch.fx.Node]) -> bool:
        if isinstance(arg, torch.fx.Node):
            return arg.target in [
                aten.sym_size,
                aten.sym_numel,
                aten.sym_stride,
            ]
        return isinstance(arg, int)

    assert len(flat_args) == len(flat_args_schema)
    for i, (arg, arg_schema) in enumerate(zip(flat_args, flat_args_schema)):
        if is_sym_int_or_int(arg) and isinstance(arg_schema, int):
            flat_args[i] = arg_schema

    args = tree_unflatten(flat_args, args_tree_spec)
    for idx, arg in enumerate(args):
        node.update_arg(idx, arg)
    return None


def _remap_arg(node_to_obj: Dict[fx.Node, Any], arg: Any) -> Any:
    if isinstance(arg, torch.fx.Node):
        obj = node_to_obj[arg]
        if _get_tracer():
            # This is a shared arg, already has a tracer from previous
            # tracing. Delete the tracer.
            del cast(Dict[Any, Any], obj.__dict__)[proxy_slot]
        return obj
    else:
        return arg


def unpack_sizes_and_dims(
    sizes: List[Union[DSymInt, int]], mesh: DeviceMesh
) -> Tuple[List[int], List[Placement]]:
    local_sizes: List[int] = [
        s.local_value if isinstance(s, DSymInt) else s for s in sizes
    ]
    placements: List[Placement] = [
        Shard(i)
        for i, a in enumerate(sizes)
        if (isinstance(a, DSymInt) and a.is_shard())
    ] or [Replicate()]

    assert len(placements) == mesh.ndim, (
        f"The number of sharded dimensions ({len(placements)}) must "
        f"match the number of dimensions in the device mesh ({mesh.ndim})."
    )

    return local_sizes, placements
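
# Illustrative example (assumed values): on a 1-D mesh, calling
# unpack_sizes_and_dims([DSymInt(global_value=8, local_value=2, mesh=mesh), 16], mesh)
# returns ([2, 16], [Shard(0)]): the local size of every dimension, plus one
# Shard placement per dimension backed by a sharded DSymInt (defaulting to
# [Replicate()] when none are sharded).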


def binop_sym_int_consumer_rule(node: fx.Node, args: Tuple[Any, ...]) -> DTensor:
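    """
    Expand a view-like op (view/_unsafe_view/expand) whose size list may
    contain sharded DSymInts: rewrite the node to use local sizes, run the
    op on the local tensor, and wrap the result as a DTensor sharded on the
    corresponding dimensions.
    """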
    assert len(args) == 2, f"Expect two args but got op {node.target} with args {args}"
    assert isinstance(
        args[0], DTensor
    ), f"Expect 1st argument to be DTensor but got {args[0]}"
    assert isinstance(args[1], list), f"Expect 2nd argument as list but got {args[1]}"

    # Extract sharded dimensions in the size list; the output DTensor should
    # follow these placements.
    local_sizes, placements = unpack_sizes_and_dims(args[1], args[0].device_mesh)

    # set node args to real int sizes.
    node.args = (node.args[0], local_sizes)
    op = cast(torch._ops.OpOverload, node.target)
    return DTensor.from_local(
        local_tensor=op(args[0]._local_tensor, local_sizes),
        device_mesh=args[0].device_mesh,
        placements=placements,
        run_check=False,
    )


def slice_backward_sym_int_consumer_rule(
    node: fx.Node, args: Tuple[Any, ...]
) -> DTensor:
    grad_output, input_sizes, dim, start, end, step = args

    local_sizes: List[int] = [
        s.local_value if isinstance(s, DSymInt) else s for s in input_sizes
    ]

    input_tensor = torch.zeros(
        local_sizes, device=grad_output.device, dtype=grad_output.dtype
    )
    return DTensor.from_local(
        local_tensor=torch.slice_scatter(
            input_tensor, grad_output.to_local(), dim, start, end, step
        ),
        device_mesh=grad_output.device_mesh,
        placements=grad_output.placements,
        run_check=False,
    )


def factory_with_sizes_rule(
    node: fx.Node,
    args: Tuple[Any, ...],
    kwargs: Dict[str, Any],
    default_mesh: DeviceMesh,
) -> DTensor:
    flat_args = tree_flatten(args)[0]
    assert not any(isinstance(a, DTensor) for a in flat_args), (
        f"Did not expect DTensor arguments for factory op, but got {node.target} "
        f"with arguments {args}."
    )
    assert isinstance(args[0], list), f"Expect 1st argument to be a list but got {args[0]}"

    local_sizes, placements = unpack_sizes_and_dims(args[0], default_mesh)
    node.args = (local_sizes, *args[1:])
    op = cast(torch._ops.OpOverload, node.target)
    return DTensor.from_local(
        local_tensor=op(*node.args, **kwargs),
        device_mesh=default_mesh,
        placements=placements,
        run_check=False,
    )


def factory_arange_rule(
    node: fx.Node,
    args: Tuple[Any, ...],
    kwargs: Dict[str, Any],
    default_mesh: DeviceMesh,
) -> DTensor:
    node.args = tree_map(lambda a: a.local_value if isinstance(a, DSymInt) else a, args)
    op = cast(torch._ops.OpOverload, node.target)
    return DTensor.from_local(
        local_tensor=op(*node.args, **kwargs),
        device_mesh=default_mesh,
        placements=[Replicate()],
        run_check=False,
    )


def default_factory_op_rule(
    node: fx.Node,
    args: Tuple[Any, ...],
    kwargs: Dict[str, Any],
    default_mesh: DeviceMesh,
) -> DTensor:
    node.args, node.kwargs = args, kwargs
    op = cast(torch._ops.OpOverload, node.target)
    return DTensor.from_local(
        local_tensor=op(*node.args, **node.kwargs),
        device_mesh=default_mesh,
        placements=[Replicate()],
        run_check=False,
    )


# Dispatch override for view and factory ops that consume SymInt arguments,
# where the output spec should follow the placement of the dimension that the
# SymInt comes from.
VIEW_SYM_INT_CONSUMERS: Dict[torch._ops.OpOverload, Callable] = {
    aten._unsafe_view.default: binop_sym_int_consumer_rule,
    aten.expand.default: binop_sym_int_consumer_rule,
    aten.slice_backward.default: slice_backward_sym_int_consumer_rule,
    aten.view.default: binop_sym_int_consumer_rule,
}

FACTORY_SYM_INT_CONSUMERS: Dict[torch._ops.OpOverload, Callable] = {
    aten.full.default: factory_with_sizes_rule,
    aten.arange.default: factory_arange_rule,
    aten.arange.start: factory_arange_rule,
}


# Dispatch override for factory ops, as DTensor cannot propagate sharding spec
# without DTensor inputs.
FACTORY_OPS: Dict[torch._ops.OpOverload, Callable] = {
    aten.scalar_tensor.default: default_factory_op_rule,
}


def _get_dtensor_dispatch_graph(
    node: fx.Node,
    node_to_obj: Dict[fx.Node, Any],
    *,
    force_make_fx: bool = False,
    default_mesh: Optional[DeviceMesh] = None,
) -> Optional[fx.GraphModule]:
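    """
    Run DTensor dispatch for a single fx node and return the traced dispatch
    subgraph that should replace it, or None when the node was handled in
    place (SymInt-consuming view/factory ops, factory ops, or ops that
    require no redistribution).
    """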
    with torch.no_grad():
        # Args should be a list of objects post remapping.
        args = tree_map(partial(_remap_arg, node_to_obj), node.args)
        kwargs = tree_map(partial(_remap_arg, node_to_obj), node.kwargs)

        op_overload = cast(torch._ops.OpOverload, node.target)

        if any(a.is_shard() for a in tree_flatten(args)[0] if isinstance(a, DSymInt)):
            if op_overload in VIEW_SYM_INT_CONSUMERS:
                assert len(kwargs) == 0, f"Expect empty kwargs, but got {kwargs}"
                node_to_obj[node] = VIEW_SYM_INT_CONSUMERS[op_overload](node, args)
                return None
            elif op_overload in FACTORY_SYM_INT_CONSUMERS:
                assert default_mesh is not None, "Requires default mesh for factory ops"
                node_to_obj[node] = FACTORY_SYM_INT_CONSUMERS[op_overload](
                    node, args, kwargs, default_mesh
                )
                return None
            else:
                assert isinstance(logger, logging.Logger)
                logger.warning(
                    "Assuming that using local_value from SymInt for %s "
                    "is mathematically correct. Full args are %s.",
                    op_overload,
                    args,
                )

        if node.target == aten.view.default:
            # HACK: this is a hack to get around the fact that some view
            # operations on a "global" tensor are invalid usage, but the view
            # operation on the batch input might still hit them, so we convert
            # the view op to reshape before calling DTensor.
            op_overload = aten.reshape.default

        # DSymInt args are not sharded on any dimension, so the local value
        # and the global value should be the same.
        args = tree_map(lambda a: a.local_value if isinstance(a, DSymInt) else a, args)
        kwargs = tree_map(
            lambda a: a.local_value if isinstance(a, DSymInt) else a, kwargs
        )

        if op_overload in FACTORY_OPS:
            # Don't pass factory ops to DTensor dispatch, as DTensor cannot
            # propagate sharding spec without DTensor inputs.
            node_to_obj[node] = FACTORY_OPS[op_overload](
                node, args, kwargs, default_mesh
            )
            return None

        # run dispatch once to get the real DTensor output.
        out, op_schema, output_sharding = _operator_dispatch(
            op_overload,
            args,
            kwargs,  # kwargs in this set of tests are all constants
            DTensor._propagator,
        )
        node_to_obj[node] = out

        assert output_sharding.schema_suggestions is not None
        target_schema = output_sharding.schema_suggestions[0]
        redistribute = target_schema is not op_schema

        # If no redistribution is needed, we don't need to replace
        # the original node.
        if not redistribute:
            _update_node_from_op_schema(node, target_schema)
            return None

        # TODO: this is broken when kwargs contains tensors,
        # or if a non-tensor kwarg was modified by the sharding propagation
        # (to fix, port pack_args_kwargs_with_local_tensor over for kwargs as well)
        updated_args_spec, unflattened_args = _update_specs_for_redistribute(
            args, target_schema, redistribute
        )

        dispatch = partial(
            _dispatch_with_local_tensors,
            op_overload,
            kwargs=kwargs,
            specs=updated_args_spec,
        )

        gm = make_fx(dispatch, _allow_non_fake_inputs=False)(unflattened_args)
        # FIXME(@wanchaol, @mrshenli): the above seems to have accidentally
        # captured DeviceMesh tensor ops when handling inplace operators?
        # The ``_to_copy`` is not connected to the graph output, so we use DCE
        # to get rid of it, but this doesn't look correct.
        #
        # The following operators appear in the captured graph, where the
        # dtype is torch.int64.
        #
        # get_attr       _tensor_constant0  _tensor_constant0      ()
        # call_function  transpose          aten.transpose.int     (_tensor_constant0, -1, 0)
        # call_function  view               aten.view.default      (transpose, [-1, 2])
        # call_function  view_1             aten.view.default      (view, [2])
        # call_function  _to_copy           aten._to_copy.default  (view_1,)
        gm.graph.eliminate_dead_code()

        return gm


def _build_dummy_add_graph(
    dt: DTensor, node_to_obj: Dict[fx.Node, Any]
) -> Tuple[fx.GraphModule, Any]:
    """
    Creates a graph for a dummy add function from a partial DTensor.
    This dummy add is used for triggering all_reduce on a Partial DTensor
    during the DTensor expansion of the traced graph.
    Also returns the actual DTensor after resharding.
    """

    def dummy_add(grad: torch.Tensor, zero: torch.Tensor) -> torch.Tensor:
        return grad + zero

    grad: torch.Tensor = dt._local_tensor
    zero: torch.Tensor = torch.zeros_like(dt._local_tensor)

    traced_add = make_fx(dummy_add)(grad, zero)

    placeholders = [n for n in traced_add.graph.nodes if n.op == OP.PLACEHOLDER]
    call_functions = [n for n in traced_add.graph.nodes if n.op == OP.CALL_FUNCTION]
    assert len(placeholders) == 2
    assert len(call_functions) == 1
    node_to_obj[placeholders[0]] = dt
    node_to_obj[placeholders[1]] = DTensor.from_local(
        zero, dt.device_mesh, [Replicate()], run_check=False
    )

    traced_dispatch = _get_dtensor_dispatch_graph(
        call_functions[0], node_to_obj, force_make_fx=True
    )
    assert traced_dispatch is not None

    # TODO(anj): This depends on the call function node -> actual DTensor output
    # mapping that we want to avoid for SPMD expansion
    return traced_dispatch, node_to_obj[call_functions[0]]


def _convert_output(
    gm: fx.GraphModule,
    node: fx.Node,
    node_to_obj: Dict[fx.Node, Any],
) -> fx.Node:
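    """
    Ensure no graph output remains a _Partial DTensor: for every partial
    output argument, splice in the dummy-add subgraph whose wait (all_reduce)
    node replaces it, then rebuild and return the output node.
    """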
    new_args = []
    has_partial = False
    for argument in node.args[0]:  # type: ignore[union-attr]
        if not isinstance(argument, fx.Node):
            new_args.append(argument)
            continue

        obj = node_to_obj[argument]

        if not _is_partial_dtensor(obj):
            new_args.append(argument)
            continue

        has_partial = True

        # we know obj is a DTensor from the _is_partial_dtensor check above
        dt = cast(DTensor, obj)

        traced_dispatch, result_obj = _build_dummy_add_graph(dt, node_to_obj)

        wait = [
            n
            for n in traced_dispatch.graph.nodes
            if n.name == "wait_comm" or n.name == "wait_tensor"
        ]
        add = [n for n in traced_dispatch.graph.nodes if n.name == "add"]
        assert len(wait) == 1 and len(add) == 1

        # remove the add node and replace it with the wait node
        add[0].replace_all_uses_with(wait[0])
        traced_dispatch.graph.eliminate_dead_code()
        # also update the actual DTensor corresponding to the node
        # TODO(anj): We require mapping of the final DTensor output to the wait
        # comm node.
        node_to_obj[wait[0]] = result_obj

        value_remap: Dict[fx.Node, fx.Node] = {}
        for dtn in traced_dispatch.graph.nodes:
            if dtn.op == OP.PLACEHOLDER:
                # map the placeholder to the current output argument; its
                # DTensor was already registered in node_to_obj above
                value_remap[dtn] = argument
            elif dtn.op == OP.OUTPUT:
                assert (
                    len(dtn.args) == 1 and len(dtn.args[0]) == 1
                ), f"Expecting a single output, but got {dtn.args} {len(dtn.args)}"
                new_args.append(value_remap[dtn.args[0][0]])
                # the concrete DTensor value of the output was added when
                # creating the inner graph (in _build_dummy_add_graph). Just add
                # it to the final output node so that we can report the final
                # output specs correctly.
                # TODO(anj): We are depending on the concrete DTensor output of
                # the dummy add.
                node_to_obj[value_remap[dtn.args[0][0]]] = node_to_obj[dtn.args[0][0]]
            else:
                if dtn.op == OP.GET_ATTR:
                    setattr(
                        gm,
                        dtn.target,
                        getattr(traced_dispatch, dtn.target),
                    )
                with gm.graph.inserting_before(node):
                    value_remap[dtn] = gm.graph.node_copy(dtn, lambda n: value_remap[n])
    if has_partial:
        gm.graph.erase_node(node)
        return gm.graph.output(new_args)
    else:
        return node


def _rebuild_graph(
    gm: fx.GraphModule,
    node_replacements: Dict[torch.fx.Node, torch.fx.GraphModule],
) -> None:
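    """
    Splice every node's DTensor dispatch subgraph into ``gm`` in place of the
    original node: remap the subgraph's placeholders to the node's flattened
    args, reroute uses of the node to the subgraph's output, then erase the
    node, DCE, and recompile.
    """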
    # replace nodes in the local traced graph with DTensor's dispatch graph
    for node in gm.graph.nodes:
        if node not in node_replacements:
            continue

        traced_dispatch = node_replacements[node]
        # Map DT's dispatch graph input placeholder nodes to the ones in
        # the local traced graph. It uses index-based accessing, which is
        # brittle, just for testing purposes.
        flatten_args, _ = tree_flatten(node.args)
        i, value_remap = 0, {}
        for dtn in traced_dispatch.graph.nodes:
            if dtn.op == OP.PLACEHOLDER:
                value_remap[dtn] = flatten_args[i]
                i += 1

        # insert DT's dispatch graph into the traced local graph.
        with gm.graph.inserting_before(node):
            for dtn in traced_dispatch.graph.nodes:
                if dtn.op == OP.PLACEHOLDER:
                    # do nothing, ignore placeholders, as they have already
                    # been prepared in value_remap
                    pass
                elif dtn.op == OP.OUTPUT:
                    assert (
                        len(dtn.args) == 1
                    ), f"Expecting a single output, but got {dtn.args} {len(dtn.args[0])}"
                    outputs = dtn.args[0]
                    # we currently support two very specific types of output
                    # 1. single output
                    # 2. multiple outputs resulting from getitem of all elements of a tuple
                    if len(outputs) == 1:
                        # for a single output, we replace the node with that node
                        output = outputs[0]
                    else:
                        # for multiple outputs, we check that these outputs correspond
                        # to all elements of a tuple. In that case, we replace
                        # uses of the output directly with the original tuple
                        source = None
                        for i, out in enumerate(outputs):
                            # we allow None outputs for certain items in the tuple
                            if out is None:
                                continue
                            assert out.op == "call_function"
                            assert out.target.__module__ == "_operator"
                            assert out.target.__name__ == "getitem"
                            assert source is None or source == out.args[0]
                            source = out.args[0]
                            assert out.args[1] == i
                        assert source is not None
                        output = source

                    new_node = value_remap[output]
                    node.replace_all_uses_with(new_node)
                else:
                    value_remap[dtn] = gm.graph.node_copy(dtn, lambda n: value_remap[n])
                    if all(
                        isinstance(n.target, torch._ops.OpOverload)
                        and n.target._schema.name.startswith(
                            ("aten::_foreach", "aten::_fused_adam")
                        )
                        for n in [dtn, node]
                    ):
                        # FIXME(@mrshenli): This is a temporary solution to
                        # enable foreach ops. The problem is that foreach ops
                        # return List[Tensor], but make_fx will flatten that
                        # before passing those tensors to the output node,
                        # which introduces additional getitem nodes. These
                        # redundant getitem nodes break graph correctness, as
                        # we cannot do getitem(getitem(foreach_out, 0), 0).
                        # This temporary solution skips getitem nodes in
                        # DTensor expanded subgraphs.
                        node.replace_all_uses_with(value_remap[dtn])
                        break
        # explicitly erase the node instead of relying on DCE, as DCE does not
        # remove inplace copy_ correctly.
        gm.graph.erase_node(node)

    gm.graph.eliminate_dead_code()
    gm.recompile()


def _get_last_consumer_to_nodes(
    graph: fx.Graph,
) -> Dict[fx.Node, List[fx.Node]]:
    # Run through the nodes in reverse and record the first instance of a use
    # of a given node. This represents the *last* use of the node in the
    # execution order of the program, which we will use to free unused
    # values.
    node_to_last_consumer: Dict[fx.Node, fx.Node] = {}
    last_consumer_to_nodes: Dict[fx.Node, List[fx.Node]] = {}

    def _register_final_consumer(arg_node: fx.Node, consumer: fx.Node) -> None:
        if arg_node not in node_to_last_consumer:
            node_to_last_consumer[arg_node] = consumer
            last_consumer_to_nodes.setdefault(consumer, []).append(arg_node)

    for node in reversed(graph.nodes):
        fx.node.map_arg(
            node.args, lambda arg_node: _register_final_consumer(arg_node, node)
        )
        fx.node.map_arg(
            node.kwargs,
            lambda kwarg_node: _register_final_consumer(kwarg_node, node),
        )

    return last_consumer_to_nodes


def _convert_to_distributed(
    gm: fx.GraphModule,
    inps: List[torch.Tensor],
    schemas: List[Schema],
    default_mesh: Optional[DeviceMesh] = None,
    _allow_partial: bool = False,
) -> Tuple[fx.GraphModule, Dict[str, Schema]]:
    """
    Returns:
        - transformed graph module
        - map from output name to DTensorSpec
    """
    global logger
    logger = get_logger("spmd_exp")
    operators = {getattr(operator, name) for name in operator.__all__}
    node_to_obj: Dict[fx.Node, Any] = {}
    # map a local op node in traced_f to its corresponding subgraph of
    # DTensor ops.
    node_replacements: Dict[torch.fx.Node, torch.fx.GraphModule] = {}

    last_consumer_to_nodes = _get_last_consumer_to_nodes(gm.graph)

    output_schemas: Dict[str, Schema] = {}
    for i, node in enumerate(gm.graph.nodes):
        assert logger is not None
        logger.info("node%s: op=%s target=%s", i, node.op, node.target)
        if node.op == OP.PLACEHOLDER:
            assert i < len(
                inps
            ), f"got more placeholder nodes ({i + 1}) than inputs ({len(inps)})"

            # our example inputs are local shards. Create DTensors from them.
            node_to_obj[node] = DTensor.from_local(
                inps[i].clone(),  # use clone to avoid modifications from inplace ops
                schemas[i].mesh,
                schemas[i].placements,
                # prevent running this collective in the backwards pass
                run_check=False,
            )
        elif isinstance(node.target, torch._ops.OpOverloadPacket):
            dtensor = cast(DTensor, node_to_obj[node.args[0]])
            node_to_obj[node] = DSymInt.from_node(node, dtensor)
        elif isinstance(node.target, torch._ops.OpOverload):
            replacement = _get_dtensor_dispatch_graph(
                node, node_to_obj, default_mesh=default_mesh
            )
            if replacement is not None:
                node_replacements[node] = replacement
        elif node.op == OP.OUTPUT:
            if not _allow_partial:
                # Returns an expanded dummy add node that ensures
                # that the partial output tensor has been converted
                # to a replicated tensor.
                node = _convert_output(gm, node, node_to_obj)

            # Save output sharding for the inputs to the backward pass.
            # TODO(anj): Pipe the output schema for the BW pass
            # instead of requiring the full output DTensor to be
            # materialized.
            for inp_arg in node.args[0]:
                if isinstance(inp_arg, fx.Node):
                    obj = node_to_obj[inp_arg]
                    if isinstance(obj, DTensor):
                        output_schemas[inp_arg.name] = Schema(
                            obj.device_mesh, obj.placements  # type: ignore[arg-type]
                        )
        elif node.op == OP.CALL_FUNCTION:
            args = tree_map(partial(_remap_arg, node_to_obj), node.args)
            kwargs = tree_map(partial(_remap_arg, node_to_obj), node.kwargs)

            dsymints = list(
                filter(lambda a: isinstance(a, DSymInt), args + tuple(kwargs.values()))
            )

            if node.target in operators and len(dsymints) > 0:
                assert all(
                    dsymints[0].mesh == d.mesh for d in dsymints
                ), "all DSymInts must have the same mesh."

                local_args = tree_map_only(DSymInt, lambda a: a.local_value, args)
                local_kwargs = tree_map_only(DSymInt, lambda a: a.local_value, kwargs)

                global_args = tree_map_only(DSymInt, lambda a: a.global_value, args)
                global_kwargs = tree_map_only(DSymInt, lambda a: a.global_value, kwargs)

                node.args = local_args
                node.kwargs = local_kwargs

                node_to_obj[node] = DSymInt(
                    local_value=node.target(*local_args, **local_kwargs),
                    global_value=node.target(*global_args, **global_kwargs),
                    mesh=dsymints[0].mesh,
                )
            else:
                assert len(dsymints) == 0, (
                    "SPMD expansion does not support SymInt in non-operator "
                    f"nodes, got {node.target}."
                )
                node_to_obj[node] = node.target(*args, **kwargs)
        else:
            raise ValueError(f"Unrecognized node.op type {node.op}")

        if node in last_consumer_to_nodes:
            # Save memory by deleting objs that won't be used anymore.
            for arg_node in last_consumer_to_nodes[node]:
                del node_to_obj[arg_node]

    _rebuild_graph(gm, node_replacements)

    return gm, output_schemas


class _SPMD:
    def __init__(
        self,
        dist_graph: DistributedGraph,
        param_schema: Schema,
        input_schemas: Sequence[Placement],
    ) -> None:
        self._dist_graph = dist_graph
        self._param_schema = param_schema
        # Override the default sharding of the input to the model.
        self._input_schemas = input_schemas
        # used to propagate sharding from the output of the forward pass to
        # the input of the backward pass
        self._known_specs_by_node_name: Dict[str, Schema] = {}

    def _is_param(self, t: torch.Tensor) -> bool:
        # N.B.: id(t) and id(param) do not match
        orig_module = cast(nn.Module, self._dist_graph.orig_module)
        return t.data_ptr() in (p.data_ptr() for p in orig_module.parameters())

    def _compile_wrapper(
        self,
        training_phase: TrainingPhase,
        original_inputs: List[List[torch.Tensor]],
        gm: fx.GraphModule,
        inps: List[torch.Tensor],
    ) -> fx.GraphModule:
        with maybe_disable_fake_tensor_mode():
            return self._compile(training_phase, gm, original_inputs[0])

    def _compile(
        self,
        training_phase: TrainingPhase,
        gm: fx.GraphModule,
        inps: List[torch.Tensor],
    ) -> fx.GraphModule:
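        """
        Assign a Schema to every graph input (a known spec from a previous
        phase, Replicate() for non-tensors, the parameter schema for
        parameters, a user-provided input schema, or Shard(0) by default),
        expand ``gm`` into DTensor ops via ``_convert_to_distributed``, and
        record the output specs for the backward pass.
        """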
        shard_schema: Schema = Schema(
            mesh=self._param_schema.mesh, placements=[Shard(0)]
        )
        schemas: List[Schema] = []
        inp_schema_count = 0
        nparams = 0

        # iterate through the inputs (and the initial nodes of the graph,
        # which should correspond 1:1 to those inputs)
        for inp, placeholder_node in zip(inps, gm.graph.nodes):
            # This is a no-op but we want the order of schemas
            # to match the order of inputs when we iterate through
            # the graph. Usually the non-tensor inputs are at the
            # end of the list, so we could drop the schemas for them.

            assert placeholder_node.op == "placeholder", (
                "Expected initial nodes of the GraphModule to be input placeholders. "
                f"Got {placeholder_node.op}"
            )

            known_schema = self._known_specs_by_node_name.get(placeholder_node.name)

            if known_schema is not None:
                schemas.append(known_schema)
            elif not isinstance(inp, torch.Tensor):
                schemas.append(
                    Schema(mesh=self._param_schema.mesh, placements=[Replicate()])
                )
            else:
                if self._is_param(inp):
                    schemas.append(self._param_schema)
                    nparams += 1
                elif self._input_schemas:
                    schemas.append(self._input_schemas[inp_schema_count])  # type: ignore[arg-type]
                    inp_schema_count += 1
                else:
                    schemas.append(shard_schema)

        parallelized_gm, output_specs = _convert_to_distributed(
            gm,
            inps,
            schemas,
            _allow_partial=False,
        )
        self._known_specs_by_node_name.update(output_specs)

        if training_phase == TrainingPhase.FORWARD:
            self._dist_graph.fwd_graph_modules.append(parallelized_gm)
        elif training_phase == TrainingPhase.BACKWARD:
            self._dist_graph.bwd_graph_modules.append(parallelized_gm)
        return make_boxed_func(parallelized_gm)


def distribute(
    dist_graph: DistributedGraph,
    param_schema: Schema,
    input_schemas: Sequence[Placement],
    *args: Tuple[Any, ...],
    **kwargs: Dict[str, Any],
) -> nn.Module:
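    """
    Compile ``dist_graph.orig_module`` for SPMD execution: fake-ify the
    example inputs (un-sharding those that appear in ``args``/``kwargs``),
    trace the forward and backward graphs through ``aot_module``, and expand
    each into DTensor operations, using ``param_schema`` for parameters and
    ``input_schemas`` for the remaining tensor inputs.
    """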
    flat_args, _ = tree_flatten(args)
    flat_kwargs, _ = tree_flatten(kwargs)
    input_set: Set[Any] = set(flat_args + flat_kwargs)

    fake_mode: FakeTensorMode = detect_fake_mode(input_set)
    if fake_mode is None:
        fake_mode = FakeTensorMode()

    # will update this to the original forward inputs
    original_inputs: List[Optional[Sequence[Any]]] = [None]

    def input_to_fake(input: Any) -> Any:
        if not isinstance(input, torch.Tensor):
            return input
        y = fake_mode.from_tensor(input)
        if input in input_set:
            # "unshard" our fake tensor
            # (assumes that the inputs are sharded on dim 0)
            y = y.repeat(param_schema.mesh.size(0), *((1,) * (y.ndim - 1)))
        # TODO: assume non-inputs (params, etc.) are replicated for now.
        return y

    def gather_inputs_for_compilation(
        inps: Tuple[Any, ...],
    ) -> Tuple[Any, ...]:
        original_inputs[0] = inps
        return tuple(input_to_fake(x) for x in inps)

    spmd = _SPMD(dist_graph, param_schema, input_schemas)
    compiled_m = aot_module(
        cast(nn.Module, dist_graph.orig_module),
        partial(spmd._compile_wrapper, TrainingPhase.FORWARD, original_inputs),
        partial(spmd._compile, TrainingPhase.BACKWARD),
        pre_compile_fn=gather_inputs_for_compilation,
        decompositions=_CURRENT_DECOMPOSITION_TABLE,
    )

    return compiled_m