[dtensor] tensor ops to use strategy based sharding prop (#100607)

This is the first in a series of PRs that adapts operator impls to use a
strategy based approach: each op utilizes OpStrategy and PlacementStrategy
to generate its own strategy. By combining the strategy based approach
with the op graph, we can enable more advanced op implementations
(decomposition becomes possible) and turn sharding propagation into
something closer to a constraint satisfaction problem.

This PR alone only adds some basic tensor op strategies, and it works
directly on the op graph that was already used for metadata propagation.
The tensor ops added in this PR mainly follow the strategy of one of
their args. The next set of PRs will add strategies for more ops.
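As a concrete illustration of the new extension point, here is a minimal
sketch of registering an op strategy, modeled on the default_strategy added
in this PR (the chosen op and the name follow_input_strategy are
illustrative, not part of this change):

    from typing import Dict

    import torch
    from torch.distributed._tensor.device_mesh import DeviceMesh
    from torch.distributed._tensor.op_schema import (
        OpStrategy,
        PlacementStrategy,
        StrategyType,
    )
    from torch.distributed._tensor.ops.utils import register_op_strategy
    from torch.fx import Node

    aten = torch.ops.aten

    @register_op_strategy(aten.alias.default)
    def follow_input_strategy(
        node: Node, mesh: DeviceMesh, node_to_strategy: Dict[Node, StrategyType]
    ) -> StrategyType:
        # mirror every placement that the first input strategy supports
        input_strategy = node_to_strategy[node.all_input_nodes[0]]
        assert isinstance(input_strategy, OpStrategy)
        return OpStrategy(
            [PlacementStrategy(s.output_spec) for s in input_strategy.strategies]
        )
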
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100607
Approved by: https://github.com/XilunWu
Wanchao Liang 2023-05-10 23:45:15 +00:00 committed by PyTorch MergeBot
parent d1f0c8e2d0
commit a1aa32e204
9 changed files with 266 additions and 118 deletions

View File

@@ -145,7 +145,7 @@ def _operator_dispatch(
# unwrap the args/kwargs schema
op_schema = sharding_propagator.prepare_op_schema(op_call, args, kwargs)
output_sharding = sharding_propagator.propagate_op_sharding(op_call, op_schema)
output_sharding = sharding_propagator.propagate(op_call, op_schema)
# first we need to lift some private aten aliases to public calls
if op_call in _CURRENT_DECOMPOSITION_TABLE:

View File

@@ -76,6 +76,12 @@ class OpStrategy(StrategyType):
strategy_list_str = ", ".join([str(strategy) for strategy in self.strategies])
return f"OpStrategy: [{strategy_list_str}]"
def max_num_shards(self) -> int:
"""
Returns the max number of shards across all placement strategies
"""
return max([strategy.output_spec.num_shards for strategy in self.strategies])
class TupleStrategy(StrategyType):
"""
@@ -204,6 +210,19 @@ class OpSchema:
DTensorSpec, _rebuild_tensor_from_dtensor_meta, self.kwargs_schema
)
def _inplace_rewrap_schema_suggestion(self, origin_schema: "OpSchema") -> None:
suggestion_args_spec = self.args_spec
new_arg_schema: List[object] = []
idx_of_args_spec = 0
for arg in origin_schema.args_schema:
if isinstance(arg, DTensorSpec):
new_arg_schema.append(suggestion_args_spec[idx_of_args_spec])
idx_of_args_spec += 1
else:
new_arg_schema.append(arg)
self.args_schema = tuple(new_arg_schema)
self.kwargs_schema = origin_schema.kwargs_schema
@dataclass
class OutputSharding:
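
A standalone sketch of what the rewrap above does (illustration only;
strings stand in for DTensorSpec objects): the suggested specs are
interleaved back into the original argument positions, and non-tensor args
pass through untouched.

    from typing import Callable, Sequence, Tuple

    def rewrap(
        origin_args: Sequence[object],
        suggested_specs: Sequence[object],
        is_spec: Callable[[object], bool],
    ) -> Tuple[object, ...]:
        # substitute the next suggested spec for each tensor arg, keep the rest
        specs = iter(suggested_specs)
        return tuple(next(specs) if is_spec(a) else a for a in origin_args)

    origin = ("spec_a", 2, "spec_b")     # e.g. (DTensorSpec, dim, DTensorSpec)
    suggestion = ("spec_a2", "spec_b2")  # a suggestion only carries tensor specs
    assert rewrap(origin, suggestion, lambda a: isinstance(a, str)) == (
        "spec_a2",
        2,
        "spec_b2",
    )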

View File

@@ -13,22 +13,6 @@ def _replace_char_in_str(string: str, new_char: str, idx: int) -> str:
return string[:idx] + new_char + string[idx + 1 :]
def _inplace_rewrap_schema_suggestion(
suggestion: OpSchema, input_schema: OpSchema
) -> None:
suggestion_args_spec = suggestion.args_spec
new_arg_schema: List[object] = []
idx_of_args_spec = 0
for arg in input_schema.args_schema:
if isinstance(arg, DTensorSpec):
new_arg_schema.append(suggestion_args_spec[idx_of_args_spec])
idx_of_args_spec += 1
else:
new_arg_schema.append(arg)
suggestion.args_schema = tuple(new_arg_schema)
suggestion.kwargs_schema = input_schema.kwargs_schema
def _gen_reshard_suggestions(
op_schema: OpSchema,
input_dims: List[str],
@@ -48,7 +32,7 @@ def _gen_reshard_suggestions(
)
)
suggested_schema = OpSchema(op_schema.func_schema, tuple(suggested_arg_specs), {})
_inplace_rewrap_schema_suggestion(suggested_schema, op_schema)
suggested_schema._inplace_rewrap_schema_suggestion(op_schema)
return OutputSharding(
None,
schema_suggestions=[suggested_schema],
@@ -350,7 +334,7 @@ def reduction_rule(
input_spec.mesh, reshard_dim_map, [], tensor_meta=input_spec.tensor_meta
)
schema_suggestion = OpSchema(op_schema.func_schema, (no_partial_spec,), {})
_inplace_rewrap_schema_suggestion(schema_suggestion, op_schema)
schema_suggestion._inplace_rewrap_schema_suggestion(op_schema)
return OutputSharding(
output_spec=None, schema_suggestions=[schema_suggestion]
)

View File

@@ -1,4 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
import torch
from torch.distributed._tensor.ops.common_rules import (

View File

@@ -1,5 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
from typing import cast, List, Optional, Sequence, Tuple
from typing import cast, Dict, List, Optional, Sequence, Tuple
import torch
@@ -11,38 +11,92 @@ from torch.distributed._tensor.api import (
Replicate,
Shard,
)
from torch.distributed._tensor.op_schema import OpSchema, OutputSharding
from torch.distributed._tensor.device_mesh import DeviceMesh
from torch.distributed._tensor.op_schema import (
OpSchema,
OpStrategy,
OutputSharding,
PlacementStrategy,
StrategyType,
)
from torch.distributed._tensor.ops.common_rules import pointwise_rule
from torch.distributed._tensor.ops.utils import normalize_dim, prod, register_prop_rule
from torch.distributed._tensor.ops.utils import (
normalize_dim,
prod,
register_op_strategy,
register_prop_rule,
)
from torch.fx import Node
aten = torch.ops.aten
# NOTE: the default propagation rule should apply for
# any operator that does not return a DTensor, i.e.
# for operators that only returns int/float/bool, we by
# default still propagate the spec, this is to ensure
# that we only return None for the case where the sharding
# propagation failed, and we should do auto-redistribute
def default_prop_rule(op_schema: OpSchema) -> OutputSharding:
# by default prop the first arg spec
return OutputSharding(op_schema.args_spec[0])
def prop_create_like(op_schema: OpSchema) -> OutputSharding:
# For operators that create tensors with same shape as input but
# with specific content that does not depend on the input, we
# can propagate Sharding, but we have to make sure we move from
# partial to replicated.
input_spec = op_schema.args_spec[0]
output_spec = DTensorSpec(
mesh=input_spec.mesh,
placements=tuple(
Replicate() if isinstance(p, _Partial) else p for p in input_spec.placements
),
@register_op_strategy(
[
aten._to_copy.default,
aten.clone.default,
aten.contiguous.default,
aten.copy_.default,
aten.detach.default,
aten.new_empty_strided.default, # TODO: re-think new_empty_strided
]
)
def default_strategy(
node: Node, mesh: DeviceMesh, node_to_strategy: Dict[Node, StrategyType]
) -> StrategyType:
# the default strategy simply propagates the first input strategy
select_strategy = node_to_strategy[node.all_input_nodes[0]]
assert isinstance(select_strategy, OpStrategy)
return OpStrategy(
[
PlacementStrategy(arg_strategy.output_spec)
for arg_strategy in select_strategy.strategies
]
)
return OutputSharding(output_spec=output_spec)
@register_op_strategy(
[
aten.empty_like.default,
aten.fill_.Scalar,
aten.full_like.default,
aten.ones_like.default,
aten.zero_.default,
aten.zeros_like.default,
]
)
def create_like_strategy(
node: Node, mesh: DeviceMesh, node_to_strategy: Dict[Node, StrategyType]
) -> StrategyType:
# create_like_strategy deals with ops that create tensors with the same
# shape as the input, but with specific content that does not depend on
# the input; we can propagate sharding, but we have to make sure we
# move from partial to replicated.
select_strategy = node_to_strategy[node.all_input_nodes[0]]
create_like_strategy = OpStrategy([])
assert isinstance(select_strategy, OpStrategy)
for arg_strategy in select_strategy.strategies:
arg_spec = arg_strategy.output_spec
if arg_spec.sums:
# if the arg_spec has partial placements, accept partial
# in the input_specs but output replicate for
# the corresponding mesh dims
output_spec = DTensorSpec(
mesh=arg_spec.mesh,
placements=tuple(
Replicate() if isinstance(p, _Partial) else p
for p in arg_spec.placements
),
)
create_like_strategy.strategies.append(
PlacementStrategy(output_spec=output_spec, input_specs=(arg_spec,))
)
else:
create_like_strategy.strategies.append(PlacementStrategy(arg_spec))
return create_like_strategy
@register_prop_rule(aten._local_scalar_dense.default)
@@ -85,36 +139,12 @@ def non_tensor_prop_rule(op_schema: OpSchema) -> OutputSharding:
return OutputSharding(output_spec=None)
default_prop_ops = [
aten._to_copy.default,
aten.clone.default,
aten.contiguous.default,
aten.copy_.default,
aten.detach.default,
aten.new_empty_strided.default,
]
create_like_ops = [
aten.empty_like.default,
aten.fill_.Scalar,
aten.full_like.default,
aten.ones_like.default,
aten.zero_.default,
aten.zeros_like.default,
]
new_factory_ops = [
aten.new_full.default,
aten.new_ones.default,
aten.new_zeros.default,
]
for op in default_prop_ops:
register_prop_rule(op)(default_prop_rule)
for op in create_like_ops:
register_prop_rule(op)(prop_create_like)
for op in new_factory_ops:
register_prop_rule(op)(new_factory_rule)
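
The partial-to-replicate handling in create_like_strategy can be exercised
in isolation. A minimal sketch, assuming the usual placement types (the
placements tuple is a hypothetical example):

    from torch.distributed._tensor.placement_types import (
        _Partial,
        Replicate,
        Shard,
    )

    # a create-like op such as zeros_like cannot keep a partial placement:
    # its freshly created content is not a sum of per-rank partial shards,
    # so every _Partial mesh dim becomes Replicate while Shard dims are kept
    placements = (Shard(0), _Partial())
    converted = tuple(
        Replicate() if isinstance(p, _Partial) else p for p in placements
    )
    assert converted == (Shard(0), Replicate())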

View File

@@ -7,15 +7,6 @@ import torch
from torch.distributed._tensor.api import DTensor
# pyre-fixme[3]: Return type must be annotated.
# pyre-fixme[2]: Parameter must be annotated.
def unwrap_single_placement(e):
if not isinstance(e, DTensor):
return None
assert len(e.placements) == 1, "more than one placement!"
return e.placements[0]
# convenient wrapper to register sharding propagation rules
# pyre-fixme[3]: Return type must be annotated.
# pyre-fixme[2]: Parameter must be annotated.
@@ -32,6 +23,19 @@ def register_prop_rule(op):
return wrapper
def register_op_strategy(op):
# pyre-fixme[53]: Captured variable `func` is not annotated.
# pyre-fixme[3]: Return type must be annotated.
# pyre-fixme[2]: Parameter must be annotated.
def wrapper(impl):
overloads = op if isinstance(op, list) else [op]
for overload in overloads:
DTensor._propagator.register_op_strategy(overload, impl)
return impl
return wrapper
def as_list(
x: Union[List[object], object]
# pyre-fixme[11]: Annotation `immutable_list` is not defined as a type.

View File

@@ -412,6 +412,14 @@ class DTensorSpec:
raise ValueError("tensor_meta is not set")
return len(self.tensor_meta.shape)
@property
def num_shards(self) -> int:
num_shards = 1
for i, placement in enumerate(self.placements):
if placement.is_shard():
num_shards *= self.mesh.size(i)
return num_shards
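
A worked example of num_shards on a hypothetical 2x4 mesh (the sizes below
stand in for self.mesh.size(i)):

    from torch.distributed._tensor.placement_types import Replicate, Shard

    mesh_sizes = [2, 4]  # stand-in for DeviceMesh.size(i) on a 2x4 mesh

    def num_shards(placements) -> int:
        n = 1
        for i, placement in enumerate(placements):
            if placement.is_shard():
                n *= mesh_sizes[i]
        return n

    assert num_shards((Shard(0), Replicate())) == 2  # only mesh dim 0 shards
    assert num_shards((Shard(0), Shard(1))) == 8     # both mesh dims shard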
@property
def dim_map(self) -> List[int]:
"""

View File

@@ -4,9 +4,19 @@ import torch
import torch.distributed._tensor.api as dtensor
from torch._ops import OpOverload
from torch._subclasses import FakeTensorMode
from torch.distributed._tensor.op_schema import DTensorSpec, OpSchema, OutputSharding
from torch.distributed._tensor.device_mesh import DeviceMesh
from torch.distributed._tensor.op_schema import (
DTensorSpec,
OpSchema,
OpStrategy,
OutputSharding,
OutputSpecType,
PlacementStrategy,
StrategyType,
)
from torch.fx import Node
from torch.fx.experimental.proxy_tensor import get_isolated_graphmodule
from torch.utils._pytree import tree_map
from torch.utils._pytree import tree_flatten, tree_map
"""
Print information on ops input shape and sharding for debugging purposes.
@@ -21,6 +31,10 @@ def unwrap_schema(e: object) -> object:
class ShardingPropagator:
def __init__(self) -> None:
self.op_to_rules: Dict[OpOverload, Callable[[OpSchema], OutputSharding]] = {}
self.op_strategy_funcs: Dict[
OpOverload,
Callable[[Node, DeviceMesh, Dict[Node, StrategyType]], StrategyType],
] = {}
def register_sharding_prop_rule(
self, op_overload: OpOverload, rule_func: Callable[[OpSchema], OutputSharding]
@@ -30,6 +44,16 @@ class ShardingPropagator:
"""
self.op_to_rules[op_overload] = rule_func
def register_op_strategy(
self,
op_overload: OpOverload,
rule_func: Callable[[Node, DeviceMesh, Dict[Node, StrategyType]], StrategyType],
):
"""
Register a sharding strategy generator for an operator.
"""
self.op_strategy_funcs[op_overload] = rule_func
def prepare_op_schema(
self, op_call: OpOverload, args: Tuple[object, ...], kwargs: Dict[str, object]
) -> OpSchema:
@@ -54,6 +78,103 @@ class ShardingPropagator:
return op_schema
def propagate(self, op_overload: OpOverload, op_schema: OpSchema) -> OutputSharding:
if op_overload in self.op_strategy_funcs:
# generate the op strategy for the op; this is done by propagating
# the sharding through the op graph.
op_gm = self._prepare_op_graph(op_overload, op_schema)
if op_gm is None:
return OutputSharding(None, [op_schema])
flat_args_sharding, _ = tree_flatten(
[op_schema.args_schema, op_schema.kwargs_schema]
)
node_to_strategy: Dict[Node, StrategyType] = {}
output_node = None
out_node_strategy = None
mesh = flat_args_sharding[0].mesh
placeholder_idx = 0
for node in op_gm.graph.nodes:
if node.op == "placeholder":
# set sharding on a placeholder if its corresponding arg is a DTensorSpec
if isinstance(flat_args_sharding[placeholder_idx], DTensorSpec):
strategy = PlacementStrategy(
flat_args_sharding[placeholder_idx]
)
# for eager execution, inputs only have one fixed sharding
node_to_strategy[node] = OpStrategy([strategy])
placeholder_idx += 1
elif node.op == "call_function":
if isinstance(node.target, OpOverload):
op_strategy_func = self.op_strategy_funcs[op_overload]
out_strategies = op_strategy_func(node, mesh, node_to_strategy)
node_to_strategy[node] = out_strategies
else:
raise NotImplementedError(
f"Unsupported function: {node.target}"
)
elif node.op == "output":
output_node = node.args[0]
out_node_strategy = node_to_strategy[output_node[0]]
else:
raise NotImplementedError(f"Unsupported node type: {node.op}")
# NOTE: this assumes we only have one call_function op in the
# op graph; we need to harden this logic when there are decomposed ops.
assert isinstance(out_node_strategy, OpStrategy)
# we take the first strategy for now
# TODO: add a min cost selection logic
output_strategy = out_node_strategy.strategies[0]
needs_redistribute = False
expected_input_specs = []
for idx, input_spec in enumerate(op_schema.args_spec):
desired_spec = (
output_strategy.output_spec
if output_strategy.input_specs is None
else output_strategy.input_specs[idx]
)
expected_input_specs.append(desired_spec)
if input_spec != desired_spec:
needs_redistribute = True
if needs_redistribute:
suggestion_schema = OpSchema(
op_schema.func_schema, tuple(expected_input_specs), {}
)
suggestion_schema._inplace_rewrap_schema_suggestion(op_schema)
else:
suggestion_schema = op_schema
output_sharding = OutputSharding(
output_strategy.output_spec,
[suggestion_schema],
)
if output_node is not None:
self._wrap_output_spec_meta(output_sharding.output_spec, output_node)
return output_sharding
elif op_overload in self.op_to_rules:
return self.propagate_op_sharding(op_overload, op_schema)
else:
raise NotImplementedError(
f"Operator {op_overload} does not have a sharding strategy registered."
)
def _wrap_output_spec_meta(
self, output_spec: OutputSpecType, output_nodes: Node
) -> None:
"""
Wrap the output_spec with the metadata from the output node.
"""
if output_spec is not None:
assert isinstance(output_nodes, (tuple, list))
if isinstance(output_spec, DTensorSpec):
output_spec.tensor_meta = output_nodes[0].meta["tensor_meta"]
elif isinstance(output_spec, (tuple, list)):
for i, spec in enumerate(output_spec):
if isinstance(spec, DTensorSpec):
spec.tensor_meta = output_nodes[i].meta["tensor_meta"]
def propagate_op_sharding(
self, op_overload: OpOverload, op_schema: OpSchema
) -> OutputSharding:
@@ -61,19 +182,17 @@
Propagate the sharding for an operator given the op_schema.
"""
# first we propagate the tensor metadata
output_node = self._propagate_tensor_meta(op_overload, op_schema)
output_node = None
op_gm = self._prepare_op_graph(op_overload, op_schema)
if op_gm is not None:
for node in op_gm.graph.nodes:
if node.op == "output":
output_node = node.args[0]
# then we propagate the sharding
sharding_prop_func = self.op_to_rules.get(op_overload, None)
sharding_prop_func = self.op_to_rules[op_overload]
if sharding_prop_func is None:
# step 1. If there's not even one sharding rule
# implemented for the operator, we error out.
raise NotImplementedError(
f"Operator {op_overload} does not have a DistributedTensor rule registered."
)
# step 2. there's sharding propagation rule, run
# step 1. there's sharding propagation rule, run
# sharding propagation to get the output sharding
try:
output_sharding = sharding_prop_func(op_schema)
@@ -86,7 +205,7 @@
f"Error: {e}"
) from e
# step 3. if can't get output_spec from sharding
# step 2. if can't get output_spec from sharding
# propagation (i.e. no rules apply for input
# placements), we return the output sharding
# with schema suggestions, which can be used to
@@ -110,8 +229,6 @@
# to get an eligible input, from which we will pick a
# schema suggestion based on the redistribute cost.
# For now we simply pick the first suggestion.
# TODO: implement full auto distribute with a
# simple cost estimation model
suggested_input_schema = output_sharding.schema_suggestions[0]
# run sharding propagation again with suggested schema
propagation_res = sharding_prop_func(suggested_input_schema)
@@ -126,24 +243,15 @@
# associate the output sharding with the output metadata
if output_node is not None:
output_nodes = output_node.args[0]
output_spec = output_sharding.output_spec
if output_spec is not None:
assert isinstance(output_nodes, (tuple, list))
if isinstance(output_spec, DTensorSpec):
output_spec.tensor_meta = output_nodes[0].meta["tensor_meta"]
elif isinstance(output_spec, (tuple, list)):
for i, spec in enumerate(output_spec):
if isinstance(spec, DTensorSpec):
spec.tensor_meta = output_nodes[i].meta["tensor_meta"]
self._wrap_output_spec_meta(output_sharding.output_spec, output_node)
return output_sharding
def _propagate_tensor_meta(
def _prepare_op_graph(
self,
op_overload: OpOverload,
op_schema: OpSchema,
) -> Optional[torch.fx.Node]:
) -> Optional[torch.fx.GraphModule]:
# right now we only use the graph for metadata prop, but next we will
# use the graph to do sharding prop as well
@@ -163,11 +271,7 @@
fake_kwargs = op_schema.gen_fake_kwargs()
g = get_isolated_graphmodule(op_overload, fake_args, fake_kwargs)
output = None
for node in g.graph.nodes:
if node.op == "output":
output = node
return output
return g
class _CachingPropagator(ShardingPropagator):
@@ -176,18 +280,16 @@ class _CachingPropagator(ShardingPropagator):
This is currently experimental for Tensor Parallel usage.
"""
def __init__(self, op_to_rules=None) -> None:
def __init__(self, propagator: ShardingPropagator) -> None:
super().__init__()
if op_to_rules is not None:
self.op_to_rules = op_to_rules
self.op_to_rules = propagator.op_to_rules
self.op_strategy_funcs = propagator.op_strategy_funcs
# cache table for sharding propagation results, we might need to
# limit the size of the cache table in the future
self.cached_prop_results: Dict[OpSchema, OutputSharding] = {}
def propagate_op_sharding(
self, op_overload: OpOverload, op_schema: OpSchema
) -> OutputSharding:
def propagate(self, op_overload: OpOverload, op_schema: OpSchema) -> OutputSharding:
"""
Propagate the sharding for an operator given the op_schema.
Cache the propagation results to avoid running propagation again.
@@ -196,7 +298,7 @@ class _CachingPropagator(ShardingPropagator):
return self.cached_prop_results[op_schema]
else:
# call DTensor's propagate_op_sharding to get the prop result
output_sharding = super().propagate_op_sharding(op_overload, op_schema)
output_sharding = super().propagate(op_overload, op_schema)
# update cached table
self.cached_prop_results[op_schema] = output_sharding
return output_sharding
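
The redistribute decision inside propagate above reduces to a per-argument
spec comparison against the chosen strategy. A standalone sketch of that
check (illustration only; strings stand in for DTensorSpec):

    from typing import List, Optional, Sequence

    def expected_inputs(
        output_spec: object,
        input_specs: Optional[Sequence[object]],
        num_args: int,
    ) -> List[object]:
        # a strategy that pins no explicit input_specs expects every input
        # to match its output_spec
        if input_specs is None:
            return [output_spec] * num_args
        return list(input_specs)

    actual = ["Shard(0)", "Replicate()"]
    expected = expected_inputs("Shard(0)", None, len(actual))
    needs_redistribute = any(a != e for a, e in zip(actual, expected))
    assert needs_redistribute  # the second arg must be resharded to Shard(0)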

View File

@@ -30,7 +30,7 @@ __all__ = [
# switch the DTensor propagator to use the caching propagator to speed up
# the TP eager execution time.
DTensor._propagator = _CachingPropagator(DTensor._propagator.op_to_rules)
DTensor._propagator = _CachingPropagator(DTensor._propagator)
def parallelize_module( # type: ignore[return]
module: nn.Module,