mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[dtensor] tensor ops to use strategy based sharding prop (#100607)
This is the first series of PR that adopts operator impls to use a strategy based approach, each op utilizes OpStrategy and PlacementStrategy to generate their own strategy. By utilizing the strategy based approach along with the op graph, we could enable more advanced op implementation (decomp is possible), and turn the sharding prop to be more like a contraint satisfication problem. This PR alone only adds some basic tensor op strategies, and it directly works on the op graph that was used for metadata propagation. The tensor ops added in this PR mainly follows one of the arg strategy. The next set of PRs would add more op strategies to other ops. Pull Request resolved: https://github.com/pytorch/pytorch/pull/100607 Approved by: https://github.com/XilunWu
This commit is contained in:
parent
d1f0c8e2d0
commit
a1aa32e204
|
|
@ -145,7 +145,7 @@ def _operator_dispatch(
|
|||
# unwrap the args/kwargs schema
|
||||
op_schema = sharding_propagator.prepare_op_schema(op_call, args, kwargs)
|
||||
|
||||
output_sharding = sharding_propagator.propagate_op_sharding(op_call, op_schema)
|
||||
output_sharding = sharding_propagator.propagate(op_call, op_schema)
|
||||
|
||||
# first we need to lift some private aten aliases to public calls
|
||||
if op_call in _CURRENT_DECOMPOSITION_TABLE:
|
||||
|
|
|
|||
|
|
@ -76,6 +76,12 @@ class OpStrategy(StrategyType):
|
|||
strategy_list_str = ", ".join([str(strategy) for strategy in self.strategies])
|
||||
return f"OpStrategy: [{strategy_list_str}]"
|
||||
|
||||
def max_num_shards(self) -> int:
|
||||
"""
|
||||
Returns the max number of shards across all placement strategies
|
||||
"""
|
||||
return max([strategy.output_spec.num_shards for strategy in self.strategies])
|
||||
|
||||
|
||||
class TupleStrategy(StrategyType):
|
||||
"""
|
||||
|
|
@ -204,6 +210,19 @@ class OpSchema:
|
|||
DTensorSpec, _rebuild_tensor_from_dtensor_meta, self.kwargs_schema
|
||||
)
|
||||
|
||||
def _inplace_rewrap_schema_suggestion(self, origin_schema: "OpSchema") -> None:
|
||||
suggestion_args_spec = self.args_spec
|
||||
new_arg_schema: List[object] = []
|
||||
idx_of_args_spec = 0
|
||||
for arg in origin_schema.args_schema:
|
||||
if isinstance(arg, DTensorSpec):
|
||||
new_arg_schema.append(suggestion_args_spec[idx_of_args_spec])
|
||||
idx_of_args_spec += 1
|
||||
else:
|
||||
new_arg_schema.append(arg)
|
||||
self.args_schema = tuple(new_arg_schema)
|
||||
self.kwargs_schema = origin_schema.kwargs_schema
|
||||
|
||||
|
||||
@dataclass
|
||||
class OutputSharding:
|
||||
|
|
|
|||
|
|
@ -13,22 +13,6 @@ def _replace_char_in_str(string: str, new_char: str, idx: int) -> str:
|
|||
return string[:idx] + new_char + string[idx + 1 :]
|
||||
|
||||
|
||||
def _inplace_rewrap_schema_suggestion(
|
||||
suggestion: OpSchema, input_schema: OpSchema
|
||||
) -> None:
|
||||
suggestion_args_spec = suggestion.args_spec
|
||||
new_arg_schema: List[object] = []
|
||||
idx_of_args_spec = 0
|
||||
for arg in input_schema.args_schema:
|
||||
if isinstance(arg, DTensorSpec):
|
||||
new_arg_schema.append(suggestion_args_spec[idx_of_args_spec])
|
||||
idx_of_args_spec += 1
|
||||
else:
|
||||
new_arg_schema.append(arg)
|
||||
suggestion.args_schema = tuple(new_arg_schema)
|
||||
suggestion.kwargs_schema = input_schema.kwargs_schema
|
||||
|
||||
|
||||
def _gen_reshard_suggestions(
|
||||
op_schema: OpSchema,
|
||||
input_dims: List[str],
|
||||
|
|
@ -48,7 +32,7 @@ def _gen_reshard_suggestions(
|
|||
)
|
||||
)
|
||||
suggested_schema = OpSchema(op_schema.func_schema, tuple(suggested_arg_specs), {})
|
||||
_inplace_rewrap_schema_suggestion(suggested_schema, op_schema)
|
||||
suggested_schema._inplace_rewrap_schema_suggestion(op_schema)
|
||||
return OutputSharding(
|
||||
None,
|
||||
schema_suggestions=[suggested_schema],
|
||||
|
|
@ -350,7 +334,7 @@ def reduction_rule(
|
|||
input_spec.mesh, reshard_dim_map, [], tensor_meta=input_spec.tensor_meta
|
||||
)
|
||||
schema_suggestion = OpSchema(op_schema.func_schema, (no_partial_spec,), {})
|
||||
_inplace_rewrap_schema_suggestion(schema_suggestion, op_schema)
|
||||
schema_suggestion._inplace_rewrap_schema_suggestion(op_schema)
|
||||
return OutputSharding(
|
||||
output_spec=None, schema_suggestions=[schema_suggestion]
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates
|
||||
|
||||
import torch
|
||||
|
||||
from torch.distributed._tensor.ops.common_rules import (
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates
|
||||
from typing import cast, List, Optional, Sequence, Tuple
|
||||
from typing import cast, Dict, List, Optional, Sequence, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
|
|
@ -11,38 +11,92 @@ from torch.distributed._tensor.api import (
|
|||
Replicate,
|
||||
Shard,
|
||||
)
|
||||
from torch.distributed._tensor.op_schema import OpSchema, OutputSharding
|
||||
from torch.distributed._tensor.device_mesh import DeviceMesh
|
||||
from torch.distributed._tensor.op_schema import (
|
||||
OpSchema,
|
||||
OpStrategy,
|
||||
OutputSharding,
|
||||
PlacementStrategy,
|
||||
StrategyType,
|
||||
)
|
||||
from torch.distributed._tensor.ops.common_rules import pointwise_rule
|
||||
from torch.distributed._tensor.ops.utils import normalize_dim, prod, register_prop_rule
|
||||
from torch.distributed._tensor.ops.utils import (
|
||||
normalize_dim,
|
||||
prod,
|
||||
register_op_strategy,
|
||||
register_prop_rule,
|
||||
)
|
||||
from torch.fx import Node
|
||||
|
||||
|
||||
aten = torch.ops.aten
|
||||
|
||||
|
||||
# NOTE: the default propagation rule should apply for
|
||||
# any operator that does not return a DTensor, i.e.
|
||||
# for operators that only returns int/float/bool, we by
|
||||
# default still propagate the spec, this is to ensure
|
||||
# that we only return None for the case where the sharding
|
||||
# propagation failed, and we should do auto-redistribute
|
||||
def default_prop_rule(op_schema: OpSchema) -> OutputSharding:
|
||||
# by default prop the first arg spec
|
||||
return OutputSharding(op_schema.args_spec[0])
|
||||
|
||||
|
||||
def prop_create_like(op_schema: OpSchema) -> OutputSharding:
|
||||
# For operators that create tensors with same shape as input but
|
||||
# with specific content that does not depend on the input, we
|
||||
# can propagate Sharding, but we have to make sure we move from
|
||||
# partial to replicated.
|
||||
input_spec = op_schema.args_spec[0]
|
||||
output_spec = DTensorSpec(
|
||||
mesh=input_spec.mesh,
|
||||
placements=tuple(
|
||||
Replicate() if isinstance(p, _Partial) else p for p in input_spec.placements
|
||||
),
|
||||
@register_op_strategy(
|
||||
[
|
||||
aten._to_copy.default,
|
||||
aten.clone.default,
|
||||
aten.contiguous.default,
|
||||
aten.copy_.default,
|
||||
aten.detach.default,
|
||||
aten.new_empty_strided.default, # TODO: re-think new_empty_strided
|
||||
]
|
||||
)
|
||||
def default_strategy(
|
||||
node: Node, mesh: DeviceMesh, node_to_strategy: Dict[Node, StrategyType]
|
||||
) -> StrategyType:
|
||||
# Default strategy by default just propagate the first input strategy
|
||||
select_strategy = node_to_strategy[node.all_input_nodes[0]]
|
||||
assert isinstance(select_strategy, OpStrategy)
|
||||
return OpStrategy(
|
||||
[
|
||||
PlacementStrategy(arg_strategy.output_spec)
|
||||
for arg_strategy in select_strategy.strategies
|
||||
]
|
||||
)
|
||||
return OutputSharding(output_spec=output_spec)
|
||||
|
||||
|
||||
@register_op_strategy(
|
||||
[
|
||||
aten.empty_like.default,
|
||||
aten.fill_.Scalar,
|
||||
aten.full_like.default,
|
||||
aten.ones_like.default,
|
||||
aten.zero_.default,
|
||||
aten.zeros_like.default,
|
||||
]
|
||||
)
|
||||
def create_like_strategy(
|
||||
node: Node, mesh: DeviceMesh, node_to_strategy: Dict[Node, StrategyType]
|
||||
) -> StrategyType:
|
||||
# create_like_strategy deals with ops that creating tensors with same
|
||||
# shape as input, but with specific content that does not depend on
|
||||
# the input, we can propagate sharding, but we have to make sure we
|
||||
# move from partial to replicated.
|
||||
select_strategy = node_to_strategy[node.all_input_nodes[0]]
|
||||
create_like_strategy = OpStrategy([])
|
||||
assert isinstance(select_strategy, OpStrategy)
|
||||
for arg_strategy in select_strategy.strategies:
|
||||
arg_spec = arg_strategy.output_spec
|
||||
if arg_spec.sums:
|
||||
# if the arg_spec have partial, accept partial
|
||||
# in the input_specs but output replicate for
|
||||
# those corresponding mesh dims
|
||||
output_spec = DTensorSpec(
|
||||
mesh=arg_spec.mesh,
|
||||
placements=tuple(
|
||||
Replicate() if isinstance(p, _Partial) else p
|
||||
for p in arg_spec.placements
|
||||
),
|
||||
)
|
||||
create_like_strategy.strategies.append(
|
||||
PlacementStrategy(output_spec=output_spec, input_specs=(arg_spec,))
|
||||
)
|
||||
|
||||
else:
|
||||
create_like_strategy.strategies.append(PlacementStrategy(arg_spec))
|
||||
|
||||
return create_like_strategy
|
||||
|
||||
|
||||
@register_prop_rule(aten._local_scalar_dense.default)
|
||||
|
|
@ -85,36 +139,12 @@ def non_tensor_prop_rule(op_schema: OpSchema) -> OutputSharding:
|
|||
return OutputSharding(output_spec=None)
|
||||
|
||||
|
||||
default_prop_ops = [
|
||||
aten._to_copy.default,
|
||||
aten.clone.default,
|
||||
aten.contiguous.default,
|
||||
aten.copy_.default,
|
||||
aten.detach.default,
|
||||
aten.new_empty_strided.default,
|
||||
]
|
||||
|
||||
create_like_ops = [
|
||||
aten.empty_like.default,
|
||||
aten.fill_.Scalar,
|
||||
aten.full_like.default,
|
||||
aten.ones_like.default,
|
||||
aten.zero_.default,
|
||||
aten.zeros_like.default,
|
||||
]
|
||||
|
||||
new_factory_ops = [
|
||||
aten.new_full.default,
|
||||
aten.new_ones.default,
|
||||
aten.new_zeros.default,
|
||||
]
|
||||
|
||||
for op in default_prop_ops:
|
||||
register_prop_rule(op)(default_prop_rule)
|
||||
|
||||
for op in create_like_ops:
|
||||
register_prop_rule(op)(prop_create_like)
|
||||
|
||||
for op in new_factory_ops:
|
||||
register_prop_rule(op)(new_factory_rule)
|
||||
|
||||
|
|
|
|||
|
|
@ -7,15 +7,6 @@ import torch
|
|||
from torch.distributed._tensor.api import DTensor
|
||||
|
||||
|
||||
# pyre-fixme[3]: Return type must be annotated.
|
||||
# pyre-fixme[2]: Parameter must be annotated.
|
||||
def unwrap_single_placement(e):
|
||||
if not isinstance(e, DTensor):
|
||||
return None
|
||||
assert len(e.placements) == 1, "more than one placement!"
|
||||
return e.placements[0]
|
||||
|
||||
|
||||
# convenient wrapper to register sharding propagation rules
|
||||
# pyre-fixme[3]: Return type must be annotated.
|
||||
# pyre-fixme[2]: Parameter must be annotated.
|
||||
|
|
@ -32,6 +23,19 @@ def register_prop_rule(op):
|
|||
return wrapper
|
||||
|
||||
|
||||
def register_op_strategy(op):
|
||||
# pyre-fixme[53]: Captured variable `func` is not annotated.
|
||||
# pyre-fixme[3]: Return type must be annotated.
|
||||
# pyre-fixme[2]: Parameter must be annotated.
|
||||
def wrapper(impl):
|
||||
overloads = op if isinstance(op, list) else [op]
|
||||
for overload in overloads:
|
||||
DTensor._propagator.register_op_strategy(overload, impl)
|
||||
return impl
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def as_list(
|
||||
x: Union[List[object], object]
|
||||
# pyre-fixme[11]: Annotation `immutable_list` is not defined as a type.
|
||||
|
|
|
|||
|
|
@ -412,6 +412,14 @@ class DTensorSpec:
|
|||
raise ValueError("tensor_meta is not set")
|
||||
return len(self.tensor_meta.shape)
|
||||
|
||||
@property
|
||||
def num_shards(self) -> int:
|
||||
num_shards = 1
|
||||
for i, placement in enumerate(self.placements):
|
||||
if placement.is_shard():
|
||||
num_shards *= self.mesh.size(i)
|
||||
return num_shards
|
||||
|
||||
@property
|
||||
def dim_map(self) -> List[int]:
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -4,9 +4,19 @@ import torch
|
|||
import torch.distributed._tensor.api as dtensor
|
||||
from torch._ops import OpOverload
|
||||
from torch._subclasses import FakeTensorMode
|
||||
from torch.distributed._tensor.op_schema import DTensorSpec, OpSchema, OutputSharding
|
||||
from torch.distributed._tensor.device_mesh import DeviceMesh
|
||||
from torch.distributed._tensor.op_schema import (
|
||||
DTensorSpec,
|
||||
OpSchema,
|
||||
OpStrategy,
|
||||
OutputSharding,
|
||||
OutputSpecType,
|
||||
PlacementStrategy,
|
||||
StrategyType,
|
||||
)
|
||||
from torch.fx import Node
|
||||
from torch.fx.experimental.proxy_tensor import get_isolated_graphmodule
|
||||
from torch.utils._pytree import tree_map
|
||||
from torch.utils._pytree import tree_flatten, tree_map
|
||||
|
||||
"""
|
||||
Print information on ops input shape and sharding for debugging purposes.
|
||||
|
|
@ -21,6 +31,10 @@ def unwrap_schema(e: object) -> object:
|
|||
class ShardingPropagator:
|
||||
def __init__(self) -> None:
|
||||
self.op_to_rules: Dict[OpOverload, Callable[[OpSchema], OutputSharding]] = {}
|
||||
self.op_strategy_funcs: Dict[
|
||||
OpOverload,
|
||||
Callable[[Node, DeviceMesh, Dict[Node, StrategyType]], StrategyType],
|
||||
] = {}
|
||||
|
||||
def register_sharding_prop_rule(
|
||||
self, op_overload: OpOverload, rule_func: Callable[[OpSchema], OutputSharding]
|
||||
|
|
@ -30,6 +44,16 @@ class ShardingPropagator:
|
|||
"""
|
||||
self.op_to_rules[op_overload] = rule_func
|
||||
|
||||
def register_op_strategy(
|
||||
self,
|
||||
op_overload: OpOverload,
|
||||
rule_func: Callable[[Node, DeviceMesh, Dict[Node, StrategyType]], StrategyType],
|
||||
):
|
||||
"""
|
||||
Register a sharding strategy generator for an operator.
|
||||
"""
|
||||
self.op_strategy_funcs[op_overload] = rule_func
|
||||
|
||||
def prepare_op_schema(
|
||||
self, op_call: OpOverload, args: Tuple[object, ...], kwargs: Dict[str, object]
|
||||
) -> OpSchema:
|
||||
|
|
@ -54,6 +78,103 @@ class ShardingPropagator:
|
|||
|
||||
return op_schema
|
||||
|
||||
def propagate(self, op_overload: OpOverload, op_schema: OpSchema) -> OutputSharding:
|
||||
if op_overload in self.op_strategy_funcs:
|
||||
# generate op strategy for the op, this is done by propagating
|
||||
# the sharding in the graph.
|
||||
op_gm = self._prepare_op_graph(op_overload, op_schema)
|
||||
if op_gm is None:
|
||||
return OutputSharding(None, [op_schema])
|
||||
|
||||
flat_args_sharding, _ = tree_flatten(
|
||||
[op_schema.args_schema, op_schema.kwargs_schema]
|
||||
)
|
||||
node_to_strategy: Dict[Node, StrategyType] = {}
|
||||
output_node = None
|
||||
out_node_strategy = None
|
||||
mesh = flat_args_sharding[0].mesh
|
||||
placeholder_idx = 0
|
||||
for node in op_gm.graph.nodes:
|
||||
if node.op == "placeholder":
|
||||
# set sharding to placeholders if it's Node
|
||||
if isinstance(flat_args_sharding[placeholder_idx], DTensorSpec):
|
||||
strategy = PlacementStrategy(
|
||||
flat_args_sharding[placeholder_idx]
|
||||
)
|
||||
# for eager execution, inputs only have one fixed sharding
|
||||
node_to_strategy[node] = OpStrategy([strategy])
|
||||
placeholder_idx += 1
|
||||
elif node.op == "call_function":
|
||||
if isinstance(node.target, OpOverload):
|
||||
op_strategy_func = self.op_strategy_funcs[op_overload]
|
||||
out_strategies = op_strategy_func(node, mesh, node_to_strategy)
|
||||
node_to_strategy[node] = out_strategies
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"Unsupported function: {node.target}"
|
||||
)
|
||||
elif node.op == "output":
|
||||
output_node = node.args[0]
|
||||
out_node_strategy = node_to_strategy[output_node[0]]
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported node type: {node.op}")
|
||||
|
||||
# NOTE: This had the assumption we only have one call_function op in the
|
||||
# op graph, we need to harden this logic when there're decomposed ops.
|
||||
assert isinstance(out_node_strategy, OpStrategy)
|
||||
# we take the first strategy for now
|
||||
# TODO: add a min cost selection logic
|
||||
output_strategy = out_node_strategy.strategies[0]
|
||||
needs_redistribute = False
|
||||
expected_input_specs = []
|
||||
for idx, input_spec in enumerate(op_schema.args_spec):
|
||||
desired_spec = (
|
||||
output_strategy.output_spec
|
||||
if output_strategy.input_specs is None
|
||||
else output_strategy.input_specs[idx]
|
||||
)
|
||||
expected_input_specs.append(desired_spec)
|
||||
if input_spec != desired_spec:
|
||||
needs_redistribute = True
|
||||
|
||||
if needs_redistribute:
|
||||
suggestion_schema = OpSchema(
|
||||
op_schema.func_schema, tuple(expected_input_specs), {}
|
||||
)
|
||||
suggestion_schema._inplace_rewrap_schema_suggestion(op_schema)
|
||||
else:
|
||||
suggestion_schema = op_schema
|
||||
|
||||
output_sharding = OutputSharding(
|
||||
output_strategy.output_spec,
|
||||
[suggestion_schema],
|
||||
)
|
||||
if output_node is not None:
|
||||
self._wrap_output_spec_meta(output_sharding.output_spec, output_node)
|
||||
return output_sharding
|
||||
|
||||
elif op_overload in self.op_to_rules:
|
||||
return self.propagate_op_sharding(op_overload, op_schema)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"Operator {op_overload} does not have a sharding strategy registered."
|
||||
)
|
||||
|
||||
def _wrap_output_spec_meta(
|
||||
self, output_spec: OutputSpecType, output_nodes: Node
|
||||
) -> None:
|
||||
"""
|
||||
Wrap the output_spec with the metadata from the output node.
|
||||
"""
|
||||
if output_spec is not None:
|
||||
assert isinstance(output_nodes, (tuple, list))
|
||||
if isinstance(output_spec, DTensorSpec):
|
||||
output_spec.tensor_meta = output_nodes[0].meta["tensor_meta"]
|
||||
elif isinstance(output_spec, (tuple, list)):
|
||||
for i, spec in enumerate(output_spec):
|
||||
if isinstance(spec, DTensorSpec):
|
||||
spec.tensor_meta = output_nodes[i].meta["tensor_meta"]
|
||||
|
||||
def propagate_op_sharding(
|
||||
self, op_overload: OpOverload, op_schema: OpSchema
|
||||
) -> OutputSharding:
|
||||
|
|
@ -61,19 +182,17 @@ class ShardingPropagator:
|
|||
Propagate the sharding for an operator given the op_schema.
|
||||
"""
|
||||
# first we propagate the tensor metadata
|
||||
output_node = self._propagate_tensor_meta(op_overload, op_schema)
|
||||
output_node = None
|
||||
op_gm = self._prepare_op_graph(op_overload, op_schema)
|
||||
if op_gm is not None:
|
||||
for node in op_gm.graph.nodes:
|
||||
if node.op == "output":
|
||||
output_node = node.args[0]
|
||||
|
||||
# then we propagate the sharding
|
||||
sharding_prop_func = self.op_to_rules.get(op_overload, None)
|
||||
sharding_prop_func = self.op_to_rules[op_overload]
|
||||
|
||||
if sharding_prop_func is None:
|
||||
# step 1. If there's not even one sharding rule
|
||||
# implemented for the operator, we error out.
|
||||
raise NotImplementedError(
|
||||
f"Operator {op_overload} does not have a DistributedTensor rule registered."
|
||||
)
|
||||
|
||||
# step 2. there's sharding propagation rule, run
|
||||
# step 1. there's sharding propagation rule, run
|
||||
# sharding propagation to get the output sharding
|
||||
try:
|
||||
output_sharding = sharding_prop_func(op_schema)
|
||||
|
|
@ -86,7 +205,7 @@ class ShardingPropagator:
|
|||
f"Error: {e}"
|
||||
) from e
|
||||
|
||||
# step 3. if can't get output_spec from sharding
|
||||
# step 2. if can't get output_spec from sharding
|
||||
# propagation (i.e. no rules apply for input
|
||||
# placements), we return the output sharding
|
||||
# with schema suggestions, which can be used to
|
||||
|
|
@ -110,8 +229,6 @@ class ShardingPropagator:
|
|||
# to get an eligible input, which we will pick a
|
||||
# schema suggestion base on the redistribute cost.
|
||||
# For now we simply pick the first suggestion.
|
||||
# TODO: implement full auto distribute with a
|
||||
# simple cost estimation model
|
||||
suggested_input_schema = output_sharding.schema_suggestions[0]
|
||||
# run sharding propagation again with suggested schema
|
||||
propagation_res = sharding_prop_func(suggested_input_schema)
|
||||
|
|
@ -126,24 +243,15 @@ class ShardingPropagator:
|
|||
|
||||
# associate the output sharding with the output metadata
|
||||
if output_node is not None:
|
||||
output_nodes = output_node.args[0]
|
||||
output_spec = output_sharding.output_spec
|
||||
if output_spec is not None:
|
||||
assert isinstance(output_nodes, (tuple, list))
|
||||
if isinstance(output_spec, DTensorSpec):
|
||||
output_spec.tensor_meta = output_nodes[0].meta["tensor_meta"]
|
||||
elif isinstance(output_spec, (tuple, list)):
|
||||
for i, spec in enumerate(output_spec):
|
||||
if isinstance(spec, DTensorSpec):
|
||||
spec.tensor_meta = output_nodes[i].meta["tensor_meta"]
|
||||
self._wrap_output_spec_meta(output_sharding.output_spec, output_node)
|
||||
|
||||
return output_sharding
|
||||
|
||||
def _propagate_tensor_meta(
|
||||
def _prepare_op_graph(
|
||||
self,
|
||||
op_overload: OpOverload,
|
||||
op_schema: OpSchema,
|
||||
) -> Optional[torch.fx.Node]:
|
||||
) -> Optional[torch.fx.GraphModule]:
|
||||
# right now we only use the graph for metadata prop, but next we will use
|
||||
# the graph to do sharding prop together
|
||||
|
||||
|
|
@ -163,11 +271,7 @@ class ShardingPropagator:
|
|||
fake_kwargs = op_schema.gen_fake_kwargs()
|
||||
g = get_isolated_graphmodule(op_overload, fake_args, fake_kwargs)
|
||||
|
||||
output = None
|
||||
for node in g.graph.nodes:
|
||||
if node.op == "output":
|
||||
output = node
|
||||
return output
|
||||
return g
|
||||
|
||||
|
||||
class _CachingPropagator(ShardingPropagator):
|
||||
|
|
@ -176,18 +280,16 @@ class _CachingPropagator(ShardingPropagator):
|
|||
This is currently experimental for Tensor Parallel usage.
|
||||
"""
|
||||
|
||||
def __init__(self, op_to_rules=None) -> None:
|
||||
def __init__(self, propagator: ShardingPropagator) -> None:
|
||||
super().__init__()
|
||||
if op_to_rules is not None:
|
||||
self.op_to_rules = op_to_rules
|
||||
self.op_to_rules = propagator.op_to_rules
|
||||
self.op_strategy_funcs = propagator.op_strategy_funcs
|
||||
|
||||
# cache table for sharding propagation results, we might need to
|
||||
# limit the size of the cache table in the future
|
||||
self.cached_prop_results: Dict[OpSchema, OutputSharding] = {}
|
||||
|
||||
def propagate_op_sharding(
|
||||
self, op_overload: OpOverload, op_schema: OpSchema
|
||||
) -> OutputSharding:
|
||||
def propagate(self, op_overload: OpOverload, op_schema: OpSchema) -> OutputSharding:
|
||||
"""
|
||||
Propagate the sharding for an operator given the op_schema.
|
||||
Cache the propagation results to avoid running propagation again.
|
||||
|
|
@ -196,7 +298,7 @@ class _CachingPropagator(ShardingPropagator):
|
|||
return self.cached_prop_results[op_schema]
|
||||
else:
|
||||
# call DTensor's propagate_op_sharding to get the prop result
|
||||
output_sharding = super().propagate_op_sharding(op_overload, op_schema)
|
||||
output_sharding = super().propagate(op_overload, op_schema)
|
||||
# update cached table
|
||||
self.cached_prop_results[op_schema] = output_sharding
|
||||
return output_sharding
|
||||
|
|
|
|||
|
|
@ -30,7 +30,7 @@ __all__ = [
|
|||
|
||||
# switch the DTensor propagator to use the caching propagator to speed up
|
||||
# the TP eager execution time.
|
||||
DTensor._propagator = _CachingPropagator(DTensor._propagator.op_to_rules)
|
||||
DTensor._propagator = _CachingPropagator(DTensor._propagator)
|
||||
|
||||
def parallelize_module( # type: ignore[return]
|
||||
module: nn.Module,
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user