[dtensor] tensor ops to use strategy based sharding prop (#100607)

This is the first in a series of PRs that moves operator impls to a
strategy-based approach: each op uses OpStrategy and PlacementStrategy
to generate its own sharding strategies. By combining the strategy-based
approach with the op graph, we can enable more advanced op
implementations (decomposition becomes possible) and turn sharding
propagation into something closer to a constraint satisfaction problem.

This PR alone only adds some basic tensor op strategies, and it works
directly on the op graph that was already used for metadata propagation.
The tensor ops added in this PR mainly follow one of the argument
strategies. The next set of PRs will add strategies for more ops.
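
For readers new to these abstractions, here is a minimal sketch of the
data model, simplified from the real definitions in
torch/distributed/_tensor/op_schema.py (extended in the diff below):

    from dataclasses import dataclass
    from typing import Any, List, Optional, Tuple

    @dataclass
    class PlacementStrategy:
        # one supported placement combination for an op node: the output
        # spec it produces and, optionally, the input specs it requires
        output_spec: Any  # a DTensorSpec in the real code
        input_specs: Optional[Tuple[Any, ...]] = None

    class OpStrategy:
        # all placement combinations an op node supports; sharding prop
        # picks one of them per node
        def __init__(self, strategies: List[PlacementStrategy]) -> None:
            self.strategies = strategies

Sharding prop then amounts to choosing one PlacementStrategy per node so
that neighboring choices agree, which is what makes it look like a
constraint satisfaction problem.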
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100607
Approved by: https://github.com/XilunWu
Wanchao Liang 2023-05-10 23:45:15 +00:00 committed by PyTorch MergeBot
parent d1f0c8e2d0
commit a1aa32e204
9 changed files with 266 additions and 118 deletions

View File

@@ -145,7 +145,7 @@ def _operator_dispatch(
     # unwrap the args/kwargs schema
     op_schema = sharding_propagator.prepare_op_schema(op_call, args, kwargs)
-    output_sharding = sharding_propagator.propagate_op_sharding(op_call, op_schema)
+    output_sharding = sharding_propagator.propagate(op_call, op_schema)

     # first we need to lift some private aten aliases to public calls
     if op_call in _CURRENT_DECOMPOSITION_TABLE:
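
The dispatch-side change is just the entry point rename; propagate() is
added in sharding_prop.py later in this diff, and routes between the new
strategy-based path and the existing rule-based path. Condensed from the
full implementation below:

    def propagate(self, op_overload, op_schema):
        if op_overload in self.op_strategy_funcs:
            ...  # new path: strategy-based prop over a traced op graph
        elif op_overload in self.op_to_rules:
            # existing path: per-op rule functions
            return self.propagate_op_sharding(op_overload, op_schema)
        else:
            raise NotImplementedError(
                f"Operator {op_overload} does not have a sharding strategy registered."
            )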

View File

@@ -76,6 +76,12 @@ class OpStrategy(StrategyType):
         strategy_list_str = ", ".join([str(strategy) for strategy in self.strategies])
         return f"OpStrategy: [{strategy_list_str}]"

+    def max_num_shards(self) -> int:
+        """
+        Returns the max number of shards across all placement strategies
+        """
+        return max([strategy.output_spec.num_shards for strategy in self.strategies])
+

 class TupleStrategy(StrategyType):
     """
@@ -204,6 +210,19 @@ class OpSchema:
             DTensorSpec, _rebuild_tensor_from_dtensor_meta, self.kwargs_schema
         )

+    def _inplace_rewrap_schema_suggestion(self, origin_schema: "OpSchema") -> None:
+        suggestion_args_spec = self.args_spec
+        new_arg_schema: List[object] = []
+        idx_of_args_spec = 0
+        for arg in origin_schema.args_schema:
+            if isinstance(arg, DTensorSpec):
+                new_arg_schema.append(suggestion_args_spec[idx_of_args_spec])
+                idx_of_args_spec += 1
+            else:
+                new_arg_schema.append(arg)
+        self.args_schema = tuple(new_arg_schema)
+        self.kwargs_schema = origin_schema.kwargs_schema
+

 @dataclass
 class OutputSharding:
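
To see what _inplace_rewrap_schema_suggestion does, consider a call with
a non-tensor argument mixed in. A schema suggestion is built from
DTensorSpecs only (via args_spec), so the non-tensor args of the original
call have to be woven back in positionally (hypothetical values):

    #   origin_schema.args_schema = (spec_a, 0.5, spec_b)  # tensor, scalar, tensor
    #   suggestion.args_spec      = (spec_a2, spec_b2)     # resharded tensor specs
    # After suggestion._inplace_rewrap_schema_suggestion(origin_schema):
    #   suggestion.args_schema    = (spec_a2, 0.5, spec_b2)
    #   suggestion.kwargs_schema  = origin_schema.kwargs_schema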

View File

@@ -13,22 +13,6 @@ def _replace_char_in_str(string: str, new_char: str, idx: int) -> str:
     return string[:idx] + new_char + string[idx + 1 :]

-def _inplace_rewrap_schema_suggestion(
-    suggestion: OpSchema, input_schema: OpSchema
-) -> None:
-    suggestion_args_spec = suggestion.args_spec
-    new_arg_schema: List[object] = []
-    idx_of_args_spec = 0
-    for arg in input_schema.args_schema:
-        if isinstance(arg, DTensorSpec):
-            new_arg_schema.append(suggestion_args_spec[idx_of_args_spec])
-            idx_of_args_spec += 1
-        else:
-            new_arg_schema.append(arg)
-    suggestion.args_schema = tuple(new_arg_schema)
-    suggestion.kwargs_schema = input_schema.kwargs_schema
-
-
 def _gen_reshard_suggestions(
     op_schema: OpSchema,
     input_dims: List[str],
@@ -48,7 +32,7 @@ def _gen_reshard_suggestions(
         )
     )
     suggested_schema = OpSchema(op_schema.func_schema, tuple(suggested_arg_specs), {})
-    _inplace_rewrap_schema_suggestion(suggested_schema, op_schema)
+    suggested_schema._inplace_rewrap_schema_suggestion(op_schema)
     return OutputSharding(
         None,
         schema_suggestions=[suggested_schema],
@@ -350,7 +334,7 @@ def reduction_rule(
         input_spec.mesh, reshard_dim_map, [], tensor_meta=input_spec.tensor_meta
     )
     schema_suggestion = OpSchema(op_schema.func_schema, (no_partial_spec,), {})
-    _inplace_rewrap_schema_suggestion(schema_suggestion, op_schema)
+    schema_suggestion._inplace_rewrap_schema_suggestion(op_schema)
     return OutputSharding(
         output_spec=None, schema_suggestions=[schema_suggestion]
     )

View File

@@ -1,4 +1,5 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates
 import torch
 from torch.distributed._tensor.ops.common_rules import (

View File

@@ -1,5 +1,5 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates
-from typing import cast, List, Optional, Sequence, Tuple
+from typing import cast, Dict, List, Optional, Sequence, Tuple

 import torch
@@ -11,38 +11,92 @@ from torch.distributed._tensor.api import (
     Replicate,
     Shard,
 )
-from torch.distributed._tensor.op_schema import OpSchema, OutputSharding
+from torch.distributed._tensor.device_mesh import DeviceMesh
+from torch.distributed._tensor.op_schema import (
+    OpSchema,
+    OpStrategy,
+    OutputSharding,
+    PlacementStrategy,
+    StrategyType,
+)
 from torch.distributed._tensor.ops.common_rules import pointwise_rule
-from torch.distributed._tensor.ops.utils import normalize_dim, prod, register_prop_rule
+from torch.distributed._tensor.ops.utils import (
+    normalize_dim,
+    prod,
+    register_op_strategy,
+    register_prop_rule,
+)
+from torch.fx import Node

 aten = torch.ops.aten


-# NOTE: the default propagation rule should apply for
-# any operator that does not return a DTensor, i.e.
-# for operators that only returns int/float/bool, we by
-# default still propagate the spec, this is to ensure
-# that we only return None for the case where the sharding
-# propagation failed, and we should do auto-redistribute
-def default_prop_rule(op_schema: OpSchema) -> OutputSharding:
-    # by default prop the first arg spec
-    return OutputSharding(op_schema.args_spec[0])
-
-
-def prop_create_like(op_schema: OpSchema) -> OutputSharding:
-    # For operators that create tensors with same shape as input but
-    # with specific content that does not depend on the input, we
-    # can propagate Sharding, but we have to make sure we move from
-    # partial to replicated.
-    input_spec = op_schema.args_spec[0]
-    output_spec = DTensorSpec(
-        mesh=input_spec.mesh,
-        placements=tuple(
-            Replicate() if isinstance(p, _Partial) else p for p in input_spec.placements
-        ),
-    )
-    return OutputSharding(output_spec=output_spec)
+@register_op_strategy(
+    [
+        aten._to_copy.default,
+        aten.clone.default,
+        aten.contiguous.default,
+        aten.copy_.default,
+        aten.detach.default,
+        aten.new_empty_strided.default,  # TODO: re-think new_empty_strided
+    ]
+)
+def default_strategy(
+    node: Node, mesh: DeviceMesh, node_to_strategy: Dict[Node, StrategyType]
+) -> StrategyType:
+    # Default strategy by default just propagate the first input strategy
+    select_strategy = node_to_strategy[node.all_input_nodes[0]]
+    assert isinstance(select_strategy, OpStrategy)
+    return OpStrategy(
+        [
+            PlacementStrategy(arg_strategy.output_spec)
+            for arg_strategy in select_strategy.strategies
+        ]
+    )
+
+
+@register_op_strategy(
+    [
+        aten.empty_like.default,
+        aten.fill_.Scalar,
+        aten.full_like.default,
+        aten.ones_like.default,
+        aten.zero_.default,
+        aten.zeros_like.default,
+    ]
+)
+def create_like_strategy(
+    node: Node, mesh: DeviceMesh, node_to_strategy: Dict[Node, StrategyType]
+) -> StrategyType:
+    # create_like_strategy deals with ops that create tensors with the same
+    # shape as input, but with specific content that does not depend on
+    # the input; we can propagate sharding, but we have to make sure we
+    # move from partial to replicated.
+    select_strategy = node_to_strategy[node.all_input_nodes[0]]
+    create_like_strategy = OpStrategy([])
+    assert isinstance(select_strategy, OpStrategy)
+    for arg_strategy in select_strategy.strategies:
+        arg_spec = arg_strategy.output_spec
+        if arg_spec.sums:
+            # if the arg_spec have partial, accept partial
+            # in the input_specs but output replicate for
+            # those corresponding mesh dims
+            output_spec = DTensorSpec(
+                mesh=arg_spec.mesh,
+                placements=tuple(
+                    Replicate() if isinstance(p, _Partial) else p
+                    for p in arg_spec.placements
+                ),
+            )
+            create_like_strategy.strategies.append(
+                PlacementStrategy(output_spec=output_spec, input_specs=(arg_spec,))
+            )
+        else:
+            create_like_strategy.strategies.append(PlacementStrategy(arg_spec))
+
+    return create_like_strategy


 @register_prop_rule(aten._local_scalar_dense.default)
@@ -85,36 +139,12 @@ def non_tensor_prop_rule(op_schema: OpSchema) -> OutputSharding:
     return OutputSharding(output_spec=None)


-default_prop_ops = [
-    aten._to_copy.default,
-    aten.clone.default,
-    aten.contiguous.default,
-    aten.copy_.default,
-    aten.detach.default,
-    aten.new_empty_strided.default,
-]
-
-create_like_ops = [
-    aten.empty_like.default,
-    aten.fill_.Scalar,
-    aten.full_like.default,
-    aten.ones_like.default,
-    aten.zero_.default,
-    aten.zeros_like.default,
-]
-
 new_factory_ops = [
     aten.new_full.default,
     aten.new_ones.default,
     aten.new_zeros.default,
 ]

-for op in default_prop_ops:
-    register_prop_rule(op)(default_prop_rule)
-
-for op in create_like_ops:
-    register_prop_rule(op)(prop_create_like)
-
 for op in new_factory_ops:
     register_prop_rule(op)(new_factory_rule)
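
A worked example of why create_like_strategy cannot simply keep partial
placements (hypothetical values): suppose aten.ones_like runs on an input
placed as (_Partial(), Shard(0)) on a 2-D mesh. The output's content does
not depend on the input values, so if the output kept _Partial, the
logical tensor would be read as the sum of per-rank ones, i.e.
mesh.size(0) times too large. The strategy therefore accepts the partial
input but replicates the output on that mesh dim:

    # input_specs = (DTensorSpec(mesh, (_Partial(), Shard(0))),)
    # output_spec =  DTensorSpec(mesh, (Replicate(), Shard(0)))

default_strategy, by contrast, forwards each input placement unchanged,
since ops like clone/detach preserve the input's value semantics.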

View File

@@ -7,15 +7,6 @@ import torch
 from torch.distributed._tensor.api import DTensor


-# pyre-fixme[3]: Return type must be annotated.
-# pyre-fixme[2]: Parameter must be annotated.
-def unwrap_single_placement(e):
-    if not isinstance(e, DTensor):
-        return None
-    assert len(e.placements) == 1, "more than one placement!"
-    return e.placements[0]
-
-
 # convenient wrapper to register sharding propagation rules
 # pyre-fixme[3]: Return type must be annotated.
 # pyre-fixme[2]: Parameter must be annotated.
@@ -32,6 +23,19 @@ def register_prop_rule(op):
     return wrapper


+def register_op_strategy(op):
+    # pyre-fixme[53]: Captured variable `func` is not annotated.
+    # pyre-fixme[3]: Return type must be annotated.
+    # pyre-fixme[2]: Parameter must be annotated.
+    def wrapper(impl):
+        overloads = op if isinstance(op, list) else [op]
+        for overload in overloads:
+            DTensor._propagator.register_op_strategy(overload, impl)
+        return impl
+
+    return wrapper
+
+
 def as_list(
     x: Union[List[object], object]
     # pyre-fixme[11]: Annotation `immutable_list` is not defined as a type.
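
Usage-wise, register_op_strategy mirrors register_prop_rule but also
accepts a list of overloads, as the tensor ops diff above shows. A
minimal sketch with a hypothetical op that is not part of this PR:

    @register_op_strategy(aten.relu.default)
    def relu_strategy(node, mesh, node_to_strategy):
        # follow the sole tensor input's strategies one-for-one
        select_strategy = node_to_strategy[node.all_input_nodes[0]]
        return OpStrategy(
            [PlacementStrategy(s.output_spec) for s in select_strategy.strategies]
        )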

View File

@@ -412,6 +412,14 @@ class DTensorSpec:
             raise ValueError("tensor_meta is not set")
         return len(self.tensor_meta.shape)

+    @property
+    def num_shards(self) -> int:
+        num_shards = 1
+        for i, placement in enumerate(self.placements):
+            if placement.is_shard():
+                num_shards *= self.mesh.size(i)
+        return num_shards
+
     @property
     def dim_map(self) -> List[int]:
         """

View File

@@ -4,9 +4,19 @@ import torch
 import torch.distributed._tensor.api as dtensor
 from torch._ops import OpOverload
 from torch._subclasses import FakeTensorMode
-from torch.distributed._tensor.op_schema import DTensorSpec, OpSchema, OutputSharding
+from torch.distributed._tensor.device_mesh import DeviceMesh
+from torch.distributed._tensor.op_schema import (
+    DTensorSpec,
+    OpSchema,
+    OpStrategy,
+    OutputSharding,
+    OutputSpecType,
+    PlacementStrategy,
+    StrategyType,
+)
+from torch.fx import Node
 from torch.fx.experimental.proxy_tensor import get_isolated_graphmodule
-from torch.utils._pytree import tree_map
+from torch.utils._pytree import tree_flatten, tree_map

 """
 Print information on ops input shape and sharding for debugging purposes.
@@ -21,6 +31,10 @@ def unwrap_schema(e: object) -> object:

 class ShardingPropagator:
     def __init__(self) -> None:
         self.op_to_rules: Dict[OpOverload, Callable[[OpSchema], OutputSharding]] = {}
+        self.op_strategy_funcs: Dict[
+            OpOverload,
+            Callable[[Node, DeviceMesh, Dict[Node, StrategyType]], StrategyType],
+        ] = {}

     def register_sharding_prop_rule(
         self, op_overload: OpOverload, rule_func: Callable[[OpSchema], OutputSharding]
@@ -30,6 +44,16 @@ class ShardingPropagator:
         """
         self.op_to_rules[op_overload] = rule_func

+    def register_op_strategy(
+        self,
+        op_overload: OpOverload,
+        rule_func: Callable[[Node, DeviceMesh, Dict[Node, StrategyType]], StrategyType],
+    ):
+        """
+        Register a sharding strategy generator for an operator.
+        """
+        self.op_strategy_funcs[op_overload] = rule_func
+
     def prepare_op_schema(
         self, op_call: OpOverload, args: Tuple[object, ...], kwargs: Dict[str, object]
     ) -> OpSchema:
@@ -54,6 +78,103 @@
         return op_schema

+    def propagate(self, op_overload: OpOverload, op_schema: OpSchema) -> OutputSharding:
+        if op_overload in self.op_strategy_funcs:
+            # generate op strategy for the op, this is done by propagating
+            # the sharding in the graph.
+            op_gm = self._prepare_op_graph(op_overload, op_schema)
+            if op_gm is None:
+                return OutputSharding(None, [op_schema])
+
+            flat_args_sharding, _ = tree_flatten(
+                [op_schema.args_schema, op_schema.kwargs_schema]
+            )
+            node_to_strategy: Dict[Node, StrategyType] = {}
+            output_node = None
+            out_node_strategy = None
+            mesh = flat_args_sharding[0].mesh
+            placeholder_idx = 0
+            for node in op_gm.graph.nodes:
+                if node.op == "placeholder":
+                    # set sharding to placeholders if it's Node
+                    if isinstance(flat_args_sharding[placeholder_idx], DTensorSpec):
+                        strategy = PlacementStrategy(
+                            flat_args_sharding[placeholder_idx]
+                        )
+                        # for eager execution, inputs only have one fixed sharding
+                        node_to_strategy[node] = OpStrategy([strategy])
+                    placeholder_idx += 1
+                elif node.op == "call_function":
+                    if isinstance(node.target, OpOverload):
+                        op_strategy_func = self.op_strategy_funcs[op_overload]
+                        out_strategies = op_strategy_func(node, mesh, node_to_strategy)
+                        node_to_strategy[node] = out_strategies
+                    else:
+                        raise NotImplementedError(
+                            f"Unsupported function: {node.target}"
+                        )
+                elif node.op == "output":
+                    output_node = node.args[0]
+                    out_node_strategy = node_to_strategy[output_node[0]]
+                else:
+                    raise NotImplementedError(f"Unsupported node type: {node.op}")
+
+            # NOTE: This had the assumption we only have one call_function op in the
+            # op graph, we need to harden this logic when there're decomposed ops.
+            assert isinstance(out_node_strategy, OpStrategy)
+            # we take the first strategy for now
+            # TODO: add a min cost selection logic
+            output_strategy = out_node_strategy.strategies[0]
+            needs_redistribute = False
+            expected_input_specs = []
+            for idx, input_spec in enumerate(op_schema.args_spec):
+                desired_spec = (
+                    output_strategy.output_spec
+                    if output_strategy.input_specs is None
+                    else output_strategy.input_specs[idx]
+                )
+                expected_input_specs.append(desired_spec)
+                if input_spec != desired_spec:
+                    needs_redistribute = True
+
+            if needs_redistribute:
+                suggestion_schema = OpSchema(
+                    op_schema.func_schema, tuple(expected_input_specs), {}
+                )
+                suggestion_schema._inplace_rewrap_schema_suggestion(op_schema)
+            else:
+                suggestion_schema = op_schema
+
+            output_sharding = OutputSharding(
+                output_strategy.output_spec,
+                [suggestion_schema],
+            )
+            if output_node is not None:
+                self._wrap_output_spec_meta(output_sharding.output_spec, output_node)
+
+            return output_sharding
+        elif op_overload in self.op_to_rules:
+            return self.propagate_op_sharding(op_overload, op_schema)
+        else:
+            raise NotImplementedError(
+                f"Operator {op_overload} does not have a sharding strategy registered."
+            )
+
+    def _wrap_output_spec_meta(
+        self, output_spec: OutputSpecType, output_nodes: Node
+    ) -> None:
+        """
+        Wrap the output_spec with the metadata from the output node.
+        """
+        if output_spec is not None:
+            assert isinstance(output_nodes, (tuple, list))
+            if isinstance(output_spec, DTensorSpec):
+                output_spec.tensor_meta = output_nodes[0].meta["tensor_meta"]
+            elif isinstance(output_spec, (tuple, list)):
+                for i, spec in enumerate(output_spec):
+                    if isinstance(spec, DTensorSpec):
+                        spec.tensor_meta = output_nodes[i].meta["tensor_meta"]
+
     def propagate_op_sharding(
         self, op_overload: OpOverload, op_schema: OpSchema
     ) -> OutputSharding:
@@ -61,19 +182,17 @@
         Propagate the sharding for an operator given the op_schema.
         """
         # first we propagate the tensor metadata
-        output_node = self._propagate_tensor_meta(op_overload, op_schema)
+        output_node = None
+        op_gm = self._prepare_op_graph(op_overload, op_schema)
+        if op_gm is not None:
+            for node in op_gm.graph.nodes:
+                if node.op == "output":
+                    output_node = node.args[0]

         # then we propagate the sharding
-        sharding_prop_func = self.op_to_rules.get(op_overload, None)
-        if sharding_prop_func is None:
-            # step 1. If there's not even one sharding rule
-            # implemented for the operator, we error out.
-            raise NotImplementedError(
-                f"Operator {op_overload} does not have a DistributedTensor rule registered."
-            )
-        # step 2. there's sharding propagation rule, run
+        sharding_prop_func = self.op_to_rules[op_overload]
+        # step 1. there's sharding propagation rule, run
         # sharding propagation to get the output sharding
         try:
             output_sharding = sharding_prop_func(op_schema)
@@ -86,7 +205,7 @@
                 f"Error: {e}"
             ) from e

-        # step 3. if can't get output_spec from sharding
+        # step 2. if can't get output_spec from sharding
         # propagation (i.e. no rules apply for input
         # placements), we return the output sharding
         # with schema suggestions, which can be used to
@@ -110,8 +229,6 @@
                 # to get an eligible input, which we will pick a
                 # schema suggestion base on the redistribute cost.
                 # For now we simply pick the first suggestion.
-                # TODO: implement full auto distribute with a
-                # simple cost estimation model
                 suggested_input_schema = output_sharding.schema_suggestions[0]
                 # run sharding propagation again with suggested schema
                 propagation_res = sharding_prop_func(suggested_input_schema)
@@ -126,24 +243,15 @@
         # associate the output sharding with the output metadata
         if output_node is not None:
-            output_nodes = output_node.args[0]
-            output_spec = output_sharding.output_spec
-            if output_spec is not None:
-                assert isinstance(output_nodes, (tuple, list))
-                if isinstance(output_spec, DTensorSpec):
-                    output_spec.tensor_meta = output_nodes[0].meta["tensor_meta"]
-                elif isinstance(output_spec, (tuple, list)):
-                    for i, spec in enumerate(output_spec):
-                        if isinstance(spec, DTensorSpec):
-                            spec.tensor_meta = output_nodes[i].meta["tensor_meta"]
+            self._wrap_output_spec_meta(output_sharding.output_spec, output_node)

         return output_sharding

-    def _propagate_tensor_meta(
+    def _prepare_op_graph(
         self,
         op_overload: OpOverload,
         op_schema: OpSchema,
-    ) -> Optional[torch.fx.Node]:
+    ) -> Optional[torch.fx.GraphModule]:
         # right now we only use the graph for metadata prop, but next we will use
         # the graph to do sharding prop together
@@ -163,11 +271,7 @@
         fake_kwargs = op_schema.gen_fake_kwargs()
         g = get_isolated_graphmodule(op_overload, fake_args, fake_kwargs)

-        output = None
-        for node in g.graph.nodes:
-            if node.op == "output":
-                output = node
-        return output
+        return g


 class _CachingPropagator(ShardingPropagator):
@@ -176,18 +280,16 @@
     This is currently experimental for Tensor Parallel usage.
     """

-    def __init__(self, op_to_rules=None) -> None:
+    def __init__(self, propagator: ShardingPropagator) -> None:
         super().__init__()
-        if op_to_rules is not None:
-            self.op_to_rules = op_to_rules
+        self.op_to_rules = propagator.op_to_rules
+        self.op_strategy_funcs = propagator.op_strategy_funcs
         # cache table for sharding propagation results, we might need to
         # limit the size of the cache table in the future
         self.cached_prop_results: Dict[OpSchema, OutputSharding] = {}

-    def propagate_op_sharding(
-        self, op_overload: OpOverload, op_schema: OpSchema
-    ) -> OutputSharding:
+    def propagate(self, op_overload: OpOverload, op_schema: OpSchema) -> OutputSharding:
         """
         Propagate the sharding for an operator given the op_schema.
         Cache the propagation results to avoid running propagation again.
@@ -196,7 +298,7 @@
             return self.cached_prop_results[op_schema]
         else:
             # call DTensor's propagate_op_sharding to get the prop result
-            output_sharding = super().propagate_op_sharding(op_overload, op_schema)
+            output_sharding = super().propagate(op_overload, op_schema)
             # update cached table
             self.cached_prop_results[op_schema] = output_sharding
             return output_sharding
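
Condensed, the new propagate() path above does the following per op call
(a pseudocode summary of the diff, not new behavior):

    # 1. trace the op into a one-op fx graph (_prepare_op_graph)
    # 2. seed each placeholder node with the runtime input DTensorSpec,
    #    wrapped as a single-choice OpStrategy
    # 3. run the registered strategy func on the call_function node
    # 4. pick strategies[0] (no cost model yet); if the runtime input
    #    specs differ from the chosen strategy's input_specs, return a
    #    schema suggestion so dispatch redistributes inputs and retries
    # 5. stamp tensor_meta from the traced graph onto the output spec

The _CachingPropagator change is mechanical: it now wraps an existing
propagator (inheriting both op_to_rules and op_strategy_funcs) and caches
propagate() results keyed by OpSchema.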

View File

@@ -30,7 +30,7 @@ __all__ = [

 # switch the DTensor propagator to use the caching propagator to speed up
 # the TP eager execution time.
-DTensor._propagator = _CachingPropagator(DTensor._propagator.op_to_rules)
+DTensor._propagator = _CachingPropagator(DTensor._propagator)


 def parallelize_module(  # type: ignore[return]
     module: nn.Module,