This PR is the first change in a series of refactors to the op dispatch logic that aims to:
1. remove redundant logic in op dispatch and simplify the error checking
2. reduce the number of tree_map/tree_flatten/unflatten calls to cut the overhead they introduce
3. remove CachedShardingPropagator by using functools.lru_cache directly, which speeds up not only TP but DTensor operations in general
4. stop the view ops from changing the op_schema in place (which is dangerous for sharding prop caching) and model view ops as one type of resharding instead
5. enrich output sharding to record whether the op needs a redistribute, so no explicit op schema comparison is needed to know it

This should help further reduce the CPU overhead. Benchmark results: before (without this change), aten.addmm latency: 0.476ms; after (with this change), aten.addmm latency: 0.341ms. Overall, the time for one MLP layer dropped from 13.535ms to 9.665ms.

Apart from the overhead reduction, this PR simplifies the op dispatching logic and the resharding logic (more refactoring is needed to make things cleaner, which will be done in later PRs).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107305
Approved by: https://github.com/fduwjj
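Point 3 above replaces the bespoke CachedShardingPropagator with functools.lru_cache. A minimal sketch of that caching idea (illustrative only; the helper names and the simplified schema key below are assumptions, not the actual PyTorch code):

    from functools import lru_cache

    # Hypothetical stand-in for the expensive sharding propagation rule.
    def _expensive_propagation(op_name: str, input_placements: tuple) -> tuple:
        return input_placements

    @lru_cache(maxsize=None)
    def propagate_op_sharding(op_name: str, input_placements: tuple) -> tuple:
        # lru_cache memoizes on the (hashable) arguments, so repeated dispatches
        # of the same op/placement combination skip the propagation entirely.
        return _expensive_propagation(op_name, input_placements)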
import operator
from contextlib import contextmanager
from enum import Enum

from typing import Any, cast, Dict, List, Optional, Tuple

import torch

import torch.distributed.distributed_c10d as c10d
import torch.fx as fx
import torch.library
import torch.nn as nn

import torch.utils._pytree as pytree

from torch.distributed._spmd.batch_dim_utils import BatchDimAnalyzer
from torch.distributed._tensor import DeviceMesh, distribute_tensor, Replicate, Shard

from torch.distributed._tensor._utils import compute_local_shape
from torch.distributed._tensor.op_schema import (
    OpStrategy,
    PlacementStrategy,
    StrategyType,
    TupleStrategy,
)
from torch.distributed._tensor.placement_types import _Partial, DTensorSpec, Placement
from torch.distributed._tensor.redistribute import redistribute_local_tensor
from torch.fx import GraphModule
from torch.fx.experimental.proxy_tensor import make_fx
from torch.fx.passes.shape_prop import _extract_tensor_metadata
from torch.nn.utils._named_member_accessor import NamedMemberAccessor

aten = torch.ops.aten

# Dummy op used by data parallel to tag gradients.
_spmd_lib_def = torch.library.Library("_spmd", "DEF")
_spmd_lib_def.define("tag_grad(Tensor self) -> Tensor")

_spmd_lib_impl = torch.library.Library("_spmd", "IMPL")
_spmd_lib_impl.impl("tag_grad", lambda x: x, "CompositeExplicitAutograd")


class DataParallelStyle(Enum):
    """
    We have three types of data parallel style:
    1. DEFAULT: the default data parallel style, which represents a mix of
                replicate and fully shard behavior. For each parameter that can
                be sharded evenly, we shard it, otherwise we replicate the
                parameter. This style avoids potential padding when parameters
                cannot be sharded evenly, but it generates a mix of all_reduce
                and reduce_scatter.
    2. REPLICATE: the data parallel style that replicates all model parameters.
                  This is similar to the behavior of DistributedDataParallel.
    3. FULLY_SHARD: the data parallel style that shards all model parameters. This
                    is similar to the behavior of FullyShardedDataParallel, the
                    difference being that FullyShardedDataParallel (ZeRO-3)
                    shards the model using FlatParameter based sharding,
                    while this style shards each parameter into a DTensor.
    """

    DEFAULT = 0
    REPLICATE = 1
    FULLY_SHARD = 2


class NodeType(Enum):
    """
    NodeType is an enum that records the type of the tensors in the graph.
    This is used to determine the data parallel strategy.
    """

    PARAM = 0
    ACT = 1
    GRAD = 2
    STATE = 3
    NON_TENSOR = 4  # NON_TENSOR is used to tag non-tensor nodes (i.e. graph output)


class DataParallelStrategy(OpStrategy):
    """
    DataParallelStrategy is a special case of OpStrategy that only records
    the "data parallel style" placement strategy for each fx Node.

    It takes a list of PlacementStrategy, where each PlacementStrategy describes
    one way to distribute the tensor and computation. In the data parallel case,
    there are two possible ways to distribute the parameters:
    1. replicate the parameter over a set of devices (DDP like behavior)
    2. shard the parameter on its tensor dimension 0 over a set of devices
       (FSDP like behavior).

    In addition to the strategy list, we also need to:
    1. `node_type`: record the type of each node in the graph, so that we can
        determine how to propagate in a data parallel fashion.
    2. `reduction_over_batch`: this is specifically tied to data parallel as the
        loss calculation usually results in a scalar tensor that comes from a
        reduction over the batch dimension. We need to know this information
        so that we can keep the output sharded.
    """

    def __init__(
        self,
        node_type: NodeType,
        strategy_list: List[PlacementStrategy],
        reduction_over_batch: bool = False,
    ):
        super().__init__(strategy_list)
        self.node_type = node_type
        self.reduction_over_batch = reduction_over_batch

    def __str__(self) -> str:
        return f"type: {self.node_type}, {super().__str__()}"
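
# For illustration: in the strategy map built below, a parameter node typically
# gets DataParallelStrategy(NodeType.PARAM, [replicate_strategy, shard_dim0_strategy]);
# the chosen DataParallelStyle later picks one of the two entries (REPLICATE
# selects index 0, FULLY_SHARD selects index 1).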


@contextmanager
def gradients_tagging(params: Dict[str, torch.Tensor]):
    """
    This is a helper function that tags the gradients of the parameters
    with a special tag, so that we can identify them during SPMD expansion.

    It's safe to trace those hooks; we remove the inserted nodes later.
    """

    tagging_hooks = []
    try:
        for p in params.values():
            h = p.register_hook(lambda grad: torch.ops._spmd.tag_grad(grad))
            tagging_hooks.append(h)
        yield
    finally:
        # remove those hooks after tracing
        for h in tagging_hooks:
            h.remove()
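
# Illustrative usage of the tagging context manager above (hypothetical setup,
# for documentation only): wrap the traced backward pass so every parameter
# gradient flows through a `_spmd.tag_grad` node that the strategy builder
# below can locate, e.g.
#
#   with gradients_tagging(dict(model.named_parameters())):
#       loss = train_step(model, batch)
#       loss.backward()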


def _gen_shard_strategy(
    mesh: DeviceMesh, shard_dim: int, input_specs: Optional[List[DTensorSpec]] = None
) -> PlacementStrategy:
    """
    util function to generate a shard strategy on shard_dim
    """
    return PlacementStrategy(
        output_spec=DTensorSpec(mesh=mesh, placements=(Shard(shard_dim),)),
        input_specs=input_specs,
    )


def _gen_replicate_strategy(
    mesh: DeviceMesh, input_specs: Optional[List[DTensorSpec]] = None
) -> PlacementStrategy:
    """
    util function to generate a replicate strategy
    """
    return PlacementStrategy(
        output_spec=DTensorSpec(mesh=mesh, placements=(Replicate(),)),
        input_specs=input_specs,
    )


def _gen_partial_strategy(mesh: DeviceMesh) -> PlacementStrategy:
    """
    util function to generate a partial strategy
    """
    # NOTE: we use AVG by default; whether avg reduction is appropriate depends
    #       on the loss function. For most loss functions it should do gradient
    #       averaging. There might be certain cases where it should not do
    #       gradient averaging (i.e. sum), but those are pretty rare.
    # TODO: Only NCCL supports AVG, so using a backend like Gloo would crash.
    #       We should figure out a way to support avg reduction for non-NCCL
    #       backends.
    reduce_op = c10d.ReduceOp.AVG  # type: ignore[attr-defined]
    return PlacementStrategy(
        output_spec=DTensorSpec(mesh=mesh, placements=(_Partial(reduce_op),)),
    )
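
# For illustration (assuming a 1-D DeviceMesh over 4 ranks):
#   _gen_shard_strategy(mesh, 0)  -> output placed as Shard(0): each rank keeps
#                                    one contiguous chunk along dim 0
#   _gen_replicate_strategy(mesh) -> output placed as Replicate(): every rank
#                                    keeps a full copy
#   _gen_partial_strategy(mesh)   -> output placed as _Partial(AVG): each rank
#                                    holds a partial value pending an average
#                                    reduction across the mesh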


def build_data_parallel_strategies(
    train_step_graph: GraphModule,
    num_params: int,
    num_states: int,
    mesh: DeviceMesh,
    batch_dim: int = 0,
) -> Dict[fx.Node, StrategyType]:
    """
    This function loops through the train step graph and builds the
    data parallel strategy for each fx Node
    """
    activation_idx = num_params + num_states
    non_compute_ops = [
        aten.clone.default,
        aten.detach.default,
        aten.ones_like.default,
        aten.reshape.default,
        aten.t.default,
        aten.view.default,
        torch.ops._spmd.tag_grad.default,
        operator.getitem,
    ]

    tuple_strategy_ops = [aten._fused_adam.default]

    dp_strategy_map: Dict[fx.Node, StrategyType] = {}
    batch_dim_analyzer = BatchDimAnalyzer(batch_dim)
    placeholder_idx = 0
    num_param_grad = 0

    # first we backward propagate to mark the param gradients sharding
    # with the help of the tag_grad nodes, then delete the tag_grad nodes
    for node in reversed(list(train_step_graph.graph.nodes)):
        # find a param_grad node via the tagging
        if node.target == torch.ops._spmd.tag_grad.default:
            cur_node = node
            while cur_node.target in non_compute_ops:
                cur_node = cur_node.args[0]
            partial_strategy = _gen_partial_strategy(mesh)
            dp_strategy_map[cur_node] = DataParallelStrategy(
                NodeType.GRAD, [partial_strategy]
            )
            num_param_grad += 1
            # remove the tag_grad node from graph
            node.replace_all_uses_with(node.args[0])
            train_step_graph.graph.erase_node(node)

            if num_param_grad == num_params:
                # early break if we have already processed all param_grads
                break

    # next we forward propagate to mark all the sharding
    for node in train_step_graph.graph.nodes:
        if node.op == "placeholder":
            if "val" not in node.meta:
                # NOTE: There are certain cases where the placeholder nodes do
                #       not have real tensor values:
                #       1. optimizer states can be None sometimes, i.e. SGD with
                #          no momentum, optimizer states populate `momentum` state
                #          as None, the full graph we get from `compile` would have
                #          None as the placeholder value
                #       2. function args might not only contain params or activations,
                #          but also contain other non-tensor inputs, i.e. the model
                #          and optimizer instances baked in as a placeholder, there might
                #          also be some scalar argument which is not a tensor
                #
                #       For the above cases, we create a NON_TENSOR strategy so that we
                #       know it's not a tensor and we don't need to shard it
                dp_strategy_map[node] = DataParallelStrategy(NodeType.NON_TENSOR, [])

            elif placeholder_idx < num_params:
                # during compilation there's an assumption that the first num_params
                # placeholders should be parameters
                shard_strategy = _gen_shard_strategy(mesh, 0)
                replica_strategy = _gen_replicate_strategy(mesh)
                dp_strategy_map[node] = DataParallelStrategy(
                    NodeType.PARAM, [replica_strategy, shard_strategy]
                )

            elif placeholder_idx < activation_idx:
                # optimizer states follow the same strategy as
                # the corresponding parameters
                replica_strategy = _gen_replicate_strategy(mesh)
                shard_strategy = _gen_shard_strategy(mesh, 0)

                dp_strategy_map[node] = DataParallelStrategy(
                    NodeType.STATE, [replica_strategy, shard_strategy]
                )
            else:
                activation_batch_dim_size = node.meta["val"].shape[batch_dim]
                # find the first activation node and use its batch dim size
                if batch_dim_analyzer.batch_dim_size == -1:
                    batch_dim_analyzer.init_batch_dim_size(activation_batch_dim_size)

                batch_dim_analyzer.set_batch_dim(node, batch_dim)
                shard_strategy = _gen_shard_strategy(mesh, batch_dim)
                dp_strategy_map[node] = DataParallelStrategy(
                    NodeType.ACT, [shard_strategy]
                )
            placeholder_idx += 1
elif node.op == "call_function":
|
|
# Annotate node types for the computation graph
|
|
# Data Parallel node propagation logic:
|
|
# param (non-compute) -> out: param
|
|
# grad (non-compute before/after) -> out: grad
|
|
# state -> output: state
|
|
#
|
|
# param + activation (param must be replicate, act be sharded) -> out: activation
|
|
# param/state + grad (param/state/grad be the same spec) -> out: param/state
|
|
# param + state -> out: param
|
|
|
|
if node.target in non_compute_ops:
|
|
# At this point, we should have removed all the `tag_grad` nodes in the graph
|
|
assert node.target != torch.ops._spmd.tag_grad.default
|
|
|
|
input_nodes = node.all_input_nodes
|
|
assert (
|
|
len(input_nodes) == 1
|
|
), f"non-compute op only support one input now, found node: {node} with length of inputs: {len(node.args)}"
|
|
arg_strategy = dp_strategy_map[input_nodes[0]]
|
|
|
|
if node.target == operator.getitem:
|
|
# for getitem call, just forward the strategy from the input
|
|
getitem_idx = node.args[1]
|
|
if isinstance(arg_strategy, TupleStrategy):
|
|
# for tuple strategy, we need to get the child strategy from the tuple
|
|
dp_strategy_map[node] = arg_strategy.childs[getitem_idx]
|
|
else:
|
|
# if it's not a tuple strategy, we just forward the arg strategy
|
|
dp_strategy_map[node] = arg_strategy
|
|
else:
|
|
assert isinstance(arg_strategy, DataParallelStrategy)
|
|
arg_node_type = arg_strategy.node_type
|
|
if arg_node_type == NodeType.PARAM:
|
|
replica_strategy = _gen_replicate_strategy(mesh)
|
|
dp_strategy_map[node] = DataParallelStrategy(
|
|
NodeType.PARAM, [replica_strategy]
|
|
)
|
|
elif arg_node_type == NodeType.GRAD:
|
|
partial_sig = _gen_partial_strategy(mesh)
|
|
dp_strategy_map[node] = DataParallelStrategy(
|
|
NodeType.GRAD, [partial_sig]
|
|
)
|
|
elif arg_node_type == NodeType.ACT:
|
|
arg_node_spec = batch_dim_analyzer.compute_act_spec(
|
|
input_nodes[0], mesh
|
|
)
|
|
|
|
output_spec = batch_dim_analyzer.compute_act_spec(node, mesh)
|
|
|
|
shard_strategy = PlacementStrategy(
|
|
output_spec=output_spec, input_specs=[arg_node_spec]
|
|
)
|
|
dp_strategy_map[node] = DataParallelStrategy(
|
|
NodeType.ACT, [shard_strategy]
|
|
)
|
|
else:
|
|
raise RuntimeError(
|
|
f"non compute op not supporting {arg_node_type}! "
|
|
)
|
|
|
|
# finished processing this non-compute node
|
|
continue
|
|
|
|

            # for computation nodes, we need to check all the inputs
            input_args = node.all_input_nodes
            input_specs = []
            if node in dp_strategy_map:
                # found a param_grad node whose output spec is already pre-filled;
                # fill in the expected input specs for the pre-filled strategy
                node_strategy = dp_strategy_map[node]
                assert isinstance(node_strategy, DataParallelStrategy)
                node_type = node_strategy.node_type
                assert node_type == NodeType.GRAD
                produce_param_grad_strat = node_strategy.strategies
                has_activation = False
                for arg in input_args:
                    arg_strategy = dp_strategy_map[arg]
                    assert isinstance(arg_strategy, DataParallelStrategy)
                    arg_node_type = arg_strategy.node_type
                    if arg_node_type == NodeType.ACT:
                        # activation sharded
                        has_activation = True
                        act_spec = batch_dim_analyzer.compute_act_spec(arg, mesh)

                        input_specs.append(act_spec)

                if has_activation:
                    assert len(produce_param_grad_strat) == 1
                    produce_param_grad_strat[0].input_specs = input_specs
            elif node.target in tuple_strategy_ops:
                # ops that need to build tuple strategy instead of normal strategy
                # This should happen rarely and is only needed when we need to
                # generate different node strategies for multiple outputs (e.g.
                # the fused_adam op)
                # TODO: Currently this specializes to fused optimizer ops, but we need
                # to see how to generalize this strategy building logic
                output_strategy_len = len(node.args) - 1
                tuple_strategies = []
                for i in range(output_strategy_len):
                    if not isinstance(node.args[i], list):
                        raise RuntimeError(
                            f"Expecting list as arg to build Tuple Strategy, but found type {type(node.args[i])}!"
                        )
                    # for list/tuple arg, use the first one to find out the node type
                    if len(node.args[i]) > 0:
                        arg_strategy = dp_strategy_map[node.args[i][0]]
                        assert isinstance(arg_strategy, DataParallelStrategy)
                        assert arg_strategy.node_type in [
                            NodeType.PARAM,
                            NodeType.GRAD,
                            NodeType.STATE,
                        ], "Expecting param/grad/state as arg to build Tuple Strategy!"
                        replica_strategy = _gen_replicate_strategy(mesh)
                        shard_strategy = _gen_shard_strategy(mesh, shard_dim=0)
                        out_node_strategy: StrategyType = DataParallelStrategy(
                            arg_strategy.node_type, [replica_strategy, shard_strategy]
                        )

                        tuple_strategies.append(out_node_strategy)

                output_tuple_strategy = TupleStrategy(tuple(tuple_strategies))
                dp_strategy_map[node] = output_tuple_strategy
            else:
                # NOTE: This is the common region for all regular computation ops

                input_node_types = [
                    cast(DataParallelStrategy, dp_strategy_map[arg]).node_type
                    for arg in input_args
                    if isinstance(dp_strategy_map[arg], DataParallelStrategy)
                ]
                if NodeType.GRAD in input_node_types:
                    # param/state + grad, build up acceptable strategy
                    # the strategy should be the same for all the inputs/outputs
                    # TODO: optimizer parts should follow the dtensor prop logic
                    # to support more general cases that allow optimizer states
                    # to have different shardings compared to the params
                    replica_strategy = _gen_replicate_strategy(mesh)
                    shard_strategy = _gen_shard_strategy(mesh, shard_dim=0)
                    output_node_type = NodeType.PARAM

                    non_grad_types = [t for t in input_node_types if t != NodeType.GRAD]

                    output_node_type = non_grad_types[0]
                    for non_grad_type in non_grad_types:
                        assert (
                            non_grad_type == output_node_type
                        ), f"Found more than one non grad types! Expect {output_node_type} but found {non_grad_type}!"
                    assert output_node_type in [
                        NodeType.PARAM,
                        NodeType.STATE,
                    ], f"Expecting output node type to be either state or param, but found {output_node_type}!"

                    dp_strategy_map[node] = DataParallelStrategy(
                        output_node_type, [replica_strategy, shard_strategy]
                    )
                elif NodeType.STATE in input_node_types:
                    # either param + state or state + state
                    replica_strategy = _gen_replicate_strategy(mesh)
                    shard_strategy = _gen_shard_strategy(mesh, shard_dim=0)
                    output_node_type = (
                        NodeType.PARAM
                        if NodeType.PARAM in input_node_types
                        else NodeType.STATE
                    )

                    dp_strategy_map[node] = DataParallelStrategy(
                        output_node_type, [replica_strategy, shard_strategy]
                    )
                elif NodeType.PARAM in input_node_types:
                    if NodeType.ACT in input_node_types:
                        # param + activation, build up acceptable strategy
                        # param must be replicated, activation must be sharded
                        for arg in input_args:
                            arg_strategy = dp_strategy_map[arg]
                            assert isinstance(arg_strategy, DataParallelStrategy)
                            node_type = arg_strategy.node_type
                            if node_type == NodeType.ACT:
                                # compute activation spec
                                act_spec = batch_dim_analyzer.compute_act_spec(
                                    arg, mesh
                                )

                                input_specs.append(act_spec)
                            elif node_type == NodeType.PARAM:
                                # param must be replicated
                                input_specs.append(
                                    DTensorSpec(mesh=mesh, placements=(Replicate(),))
                                )
                            else:
                                raise RuntimeError(
                                    f"Expecting node with parameter and activation, but found {input_node_types}! "
                                )
                        # produce activation type sharding for output
                        output_spec = batch_dim_analyzer.compute_act_spec(node, mesh)

                        act_strategy = PlacementStrategy(
                            output_spec=output_spec, input_specs=input_specs
                        )

                        dp_strategy_map[node] = DataParallelStrategy(
                            NodeType.ACT, [act_strategy]
                        )
                    else:
                        # If the inputs only have parameters, the
                        # strategy of this node should follow the input
                        dp_strategy_map[node] = dp_strategy_map[input_args[0]]
                else:
                    # If the input nodes do not have PARAM/GRAD/STATE, then
                    # it is a pure activation computation and it should
                    # produce an activation output.
                    # Activations are usually sharded unless the model creates
                    # new tensors during computation; depending on whether the
                    # new tensor is associated with a batch dim or not, it could
                    # be shard/replicate/partial, and the batch dim analyzer
                    # should tell us the correct sharding.
                    for arg in input_args:
                        arg_strategy = dp_strategy_map[arg]
                        assert isinstance(arg_strategy, DataParallelStrategy)
                        input_spec = batch_dim_analyzer.compute_act_spec(arg, mesh)

                        input_specs.append(input_spec)

                    act_spec = batch_dim_analyzer.compute_act_spec(node, mesh)
                    op_strategy = PlacementStrategy(
                        output_spec=act_spec, input_specs=input_specs
                    )
                    dp_strategy_map[node] = DataParallelStrategy(
                        NodeType.ACT, [op_strategy]
                    )
elif node.op == "output":
|
|
dp_strategy_map[node] = DataParallelStrategy(NodeType.NON_TENSOR, [])
|
|
else:
|
|
raise RuntimeError(f"op code {node.op} not supported")
|
|
|
|
return dp_strategy_map # type: ignore[return-value]
|
|
|
|
|
|


def mark_data_parallel_shardings(
    train_step_graph: GraphModule,
    num_parameters: int,
    num_states: int,
    dp_strategy_map: Dict[fx.Node, StrategyType],
    parallel_mode: DataParallelStyle = DataParallelStyle.FULLY_SHARD,
) -> None:
    """
    This function marks the sharding for the nodes in the train_step_graph
    """
    activation_idx = num_parameters + num_states
    placeholder_idx = 0
    for node in train_step_graph.graph.nodes:
        node_strategy = dp_strategy_map[node]
        if node.op == "placeholder":
            assert isinstance(node_strategy, DataParallelStrategy)
            node_type = node_strategy.node_type
            node_strategies = node_strategy.strategies
            if node_type == NodeType.NON_TENSOR:
                # set node sharding to None
                node_sharding = None
            elif placeholder_idx < activation_idx:
                assert len(node_strategies) > 0, "node_strategies should not be empty"
                if parallel_mode == DataParallelStyle.REPLICATE:
                    # set to replicate for replicate style
                    node_sharding = node_strategies[0]
                elif parallel_mode == DataParallelStyle.FULLY_SHARD:
                    # set to shard for fully shard style
                    if len(node_strategies) == 1:
                        # only one strategy, use that instead
                        # i.e. optimizer state steps can only be replicate
                        node_sharding = node_strategies[0]
                    else:
                        # use the full sharding strategy
                        node_sharding = node_strategies[1]
                elif parallel_mode == DataParallelStyle.DEFAULT:
                    # TODO: add support for default mode
                    # default mode would generate either replicate or shard
                    raise NotImplementedError("default mode not implemented")
            else:
                assert len(node_strategies) > 0, "node_strategies should not be empty"
                # mark activation as sharded on batch dim
                node_sharding = node_strategies[0]

            node.meta["sharding"] = node_sharding

            placeholder_idx += 1
        elif node.op == "call_function":
            if isinstance(node_strategy, TupleStrategy):
                # For a tuple strategy in the data parallel mode, all tuple
                # elements should have the same strategy; assert that, then use
                # the first element's strategy as the sharding
                first_strategy = cast(DataParallelStrategy, node_strategy.childs[0])
                for child_strategy in node_strategy.childs:
                    assert isinstance(child_strategy, DataParallelStrategy)
                    assert child_strategy.strategies == first_strategy.strategies

                node_strategies = first_strategy.strategies
            else:
                assert isinstance(node_strategy, DataParallelStrategy)
                node_strategies = node_strategy.strategies

            assert (
                len(node_strategies) <= 2
            ), "data parallel should have at most 2 strategies"
            if len(node_strategies) == 1:
                node.meta["sharding"] = node_strategies[0]
            elif len(node_strategies) == 2:
                if parallel_mode == DataParallelStyle.REPLICATE:
                    # set to replicate for replicate style
                    node.meta["sharding"] = node_strategies[0]
                elif parallel_mode == DataParallelStyle.FULLY_SHARD:
                    # set to shard for fully shard style
                    node.meta["sharding"] = node_strategies[1]
                else:
                    raise RuntimeError("default mode not supported yet!")
            else:
                raise RuntimeError(
                    f"node {node} strategy length {len(node_strategies)} is not expected!"
                )
        elif node.op == "output":
            assert (
                isinstance(node_strategy, DataParallelStrategy)
                and node_strategy.node_type == NodeType.NON_TENSOR
            ), "output node should not be tensor"
            node.meta["sharding"] = None
        else:
            raise RuntimeError(f"op code {node.op} not supported")


def _partition_val(val: Any, spec: DTensorSpec) -> Any:
    """
    util function to convert a full tensor val to its local component
    """
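    # Worked example (for illustration): with a 1-D mesh of 4 ranks and
    # spec.placements == (Shard(0),), a full tensor of shape (8, 16) is split
    # into 4 chunks along dim 0 and the caller's rank keeps its own (2, 16)
    # chunk; Replicate()/_Partial() placements leave the value untouched here.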
    if isinstance(val, torch.Tensor):
        local_shard = val
        if val.ndim == 0:
            # If it's already a scalar tensor, it is already local, we don't
            # need to do anything
            return local_shard

        for idx, placement in enumerate(spec.placements):
            if placement.is_shard():
                placement = cast(Shard, placement)
                num_chunks = spec.mesh.size(dim=idx)
                my_coord = spec.mesh.get_coordinate()
                assert my_coord is not None, "current rank not in mesh!"
                my_coord_on_mesh_dim = my_coord[idx]
                local_shard = placement._split_tensor(
                    local_shard, num_chunks, with_padding=False, contiguous=False
                )[0][my_coord_on_mesh_dim]
        return local_shard
    elif isinstance(val, (tuple, list)):
        return val.__class__(_partition_val(v, spec) for v in val)
    else:
        raise RuntimeError(f"val type {type(val)} not supported")


def partitioner(graph: GraphModule) -> GraphModule:
    """
    Graph partitioner that partitions the single device graph
    to a distributed graph
    """
    shape_adjustment_ops = {
        aten._unsafe_view.default: 1,
        aten.expand.default: 1,
        aten.new_zeros.default: 1,
        aten.ones.default: 0,
        aten.reshape.default: 1,
        aten.view.default: 1,
        aten.zeros.default: 0,
    }
    # partition the graph to distributed
    for node in graph.graph.nodes:
        node_sharding = node.meta["sharding"]
        # None sharding means this node doesn't need sharding
        if node_sharding is None:
            continue

        if node.op == "placeholder":
            out_spec = node_sharding.output_spec
            if not hasattr(out_spec, "from_local"):
                local_val = _partition_val(node.meta["val"], out_spec)
                # update node value
                node.meta["val"] = local_val
        elif node.op == "call_function":
            out_spec = node_sharding.output_spec

            # check if there's misaligned sharding, insert reshard if there is
            expected_input_specs = node_sharding.input_specs
            for idx, input_arg in enumerate(node.all_input_nodes):
                input_arg_sharding = input_arg.meta["sharding"]

                input_arg_spec = input_arg_sharding.output_spec
                desired_spec = (
                    out_spec
                    if expected_input_specs is None
                    else expected_input_specs[idx]
                )
                if input_arg_spec != desired_spec:
                    input_arg_spec.tensor_meta = input_arg.meta["tensor_meta"]
                    desired_spec.tensor_meta = input_arg.meta["tensor_meta"]
                    input_arg_tensor = input_arg.meta["val"]

                    # insert reshard operation
                    def reshard_fn(local_tensor: torch.Tensor) -> torch.Tensor:
                        return redistribute_local_tensor(
                            local_tensor,
                            input_arg_spec,
                            desired_spec,
                        )

                    reshard_gm = make_fx(reshard_fn)(input_arg_tensor)
                    reshard_gm_nodes = list(reshard_gm.graph.nodes)
                    input_node = reshard_gm_nodes[0]
                    with graph.graph.inserting_before(node):
                        output_node = graph.graph.graph_copy(
                            reshard_gm.graph,
                            val_map={
                                input_node: input_arg,
                            },
                        )
                    node.replace_input_with(input_arg, output_node)

            output_val = node.meta["val"]

            if node.target == torch.ops.aten.repeat.default:
                # for repeat op, we need to infer the repeat sizes
                assert isinstance(output_val, torch.Tensor)
                local_shape = compute_local_shape(
                    output_val.shape, out_spec.mesh, out_spec.placements
                )
                input_shape = node.args[0].meta["val"].shape

                def infer_repeat_sizes(repeated_shape, input_shape):
                    repeated_size = [1] * len(repeated_shape)
                    padded_length = len(repeated_shape) - len(input_shape)
                    for i in range(len(repeated_shape)):
                        if i < padded_length:
                            repeated_size[i] = repeated_shape[i]
                        else:
                            repeated_size[i] = (
                                repeated_shape[i] // input_shape[i - padded_length]
                            )

                    return repeated_size
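
                # Worked example (for illustration): if the op originally
                # repeats a (4, 3) input to a (4, 6) output and dim 0 is
                # sharded over 2 ranks, the local output shape is (2, 6) and
                # the local input shape is (2, 3), so the inferred repeat
                # sizes are (2 // 2, 6 // 3) == (1, 2), matching the original
                # repeat sizes.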
                node.update_arg(1, infer_repeat_sizes(local_shape, input_shape))

            elif node.target in shape_adjustment_ops:
                # for view related ops that need a shape, adjust the shape to the local shape if needed
                assert isinstance(output_val, torch.Tensor)
                local_shape = compute_local_shape(
                    output_val.shape, out_spec.mesh, out_spec.placements
                )
                shape_arg_num = shape_adjustment_ops[node.target]
                node.update_arg(shape_arg_num, local_shape)

            # convert output val to its local component
            node.meta["val"] = _partition_val(output_val, out_spec)

        elif node.op == "output":
            break
        else:
            raise RuntimeError(f"op code {node} not supported")

    # clean up the graph by removing sharding and partitioning related metadata
    for node in graph.graph.nodes:
        if "sharding" in node.meta:
            del node.meta["sharding"]
        if "val" in node.meta and isinstance(node.meta["val"], torch.Tensor):
            local_tensor_meta = _extract_tensor_metadata(node.meta["val"])
            node.meta["tensor_meta"] = local_tensor_meta

    graph.graph.lint()
    graph.recompile()
    return graph
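

# Illustrative call into the entry point below (hypothetical setup, for
# documentation only): given a traced train-step GraphModule `gm`, its model and
# optimizer, and a 1-D DeviceMesh, a fully-sharded partition could look like
#
#   partitioned = partition_data_parallel(
#       gm, model, optim,
#       dict(model.named_parameters()), named_states,
#       args, kwargs, mesh,
#       DataParallelStyle.FULLY_SHARD, input_batch_dim=0,
#   )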
def partition_data_parallel(
    graph: GraphModule,
    model: nn.Module,
    optimizer: Optional[torch.optim.Optimizer],
    params_buffers: Dict[str, torch.Tensor],
    named_states: Dict[str, Any],
    args: Tuple[Any, ...],
    kwargs: Dict[str, Any],
    mesh: DeviceMesh,
    parallel_style: DataParallelStyle,
    input_batch_dim: int,
) -> GraphModule:
    """
    The entry point function to partition the graph into a data parallel
    graph; it also shards/replicates the model parameters and optimizer
    states to DTensors.
    """
    num_params_buffers = len(params_buffers)
    flattened_states = pytree.tree_flatten(named_states)[0]
    num_states = len(flattened_states)

    changed = graph.graph.eliminate_dead_code()
    if changed:
        graph.recompile()

    # 1. First build up data parallel strategies for the whole graph
    strategy_map = build_data_parallel_strategies(
        graph, num_params_buffers, num_states, mesh=mesh, batch_dim=input_batch_dim
    )

    # 2. Next we mark the data parallel strategy for each node based on
    #    the parallel_style
    mark_data_parallel_shardings(
        graph,
        num_parameters=num_params_buffers,
        num_states=num_states,
        dp_strategy_map=strategy_map,
        parallel_mode=parallel_style,
    )

    # 3. Partition the single machine graph to the distributed graph
    partitioned_graph = partitioner(graph)

    # preserve node types for the expanded graph
    for node in partitioned_graph.graph.nodes:
        if node in strategy_map:
            node_strategy = strategy_map[node]
            if isinstance(node_strategy, DataParallelStrategy):
                node.meta["node_type"] = node_strategy.node_type
            elif isinstance(node_strategy, TupleStrategy):
                node.meta["node_type"] = NodeType.NON_TENSOR
            else:
                raise RuntimeError(f"Unknown node strategy {node_strategy}")
        else:
            # if the nodes are expanded nodes (collectives), we mark them
            # with the same type as the input node.
            input_node = node.all_input_nodes[0]
            node.meta["node_type"] = input_node.meta["node_type"]

    # 4. Last, in-place partition the weights and optim states to
    #    DTensors based on the parallel style
    accessor = NamedMemberAccessor(model)
    for param_key, param in params_buffers.items():
        placement: Placement = Replicate()
        if parallel_style == DataParallelStyle.FULLY_SHARD:
            placement = Shard(0)
        elif parallel_style != DataParallelStyle.REPLICATE:
            raise RuntimeError(f"parallel style {parallel_style} not supported yet")

        dtensor_param = distribute_tensor(param, mesh, [placement])
        # update re-parameterized module param dict and optim states dict to DTensor
        params_buffers[param_key] = dtensor_param.to_local()
        # update module parameters to DTensor
        accessor.set_tensor(param_key, dtensor_param)

        # update the optimizer state keys and values to DTensor
        if optimizer is not None and param in optimizer.state:
            param_states = named_states[param_key]
            param_dtensor_states = {}
            for state_key, state_val in param_states.items():
                if isinstance(state_val, torch.Tensor) and state_val.ndim > 0:
                    # shard/replicate non-scalar tensors; for scalar tensors, we
                    # don't do anything
                    dtensor_state = distribute_tensor(state_val, mesh, [placement])
                    param_dtensor_states[state_key] = dtensor_state
                    param_states[state_key] = dtensor_state.to_local()
                else:
                    param_dtensor_states[state_key] = state_val

            optimizer.state.pop(param)  # type: ignore[call-overload]
            optimizer.state[dtensor_param] = param_dtensor_states  # type: ignore[index]

    return partitioned_graph