[BE] fix typo in torch/distributed/tensor/: childs -> children (#156609)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156609
Approved by: https://github.com/wanchaol, https://github.com/cyyever
ghstack dependencies: #156311

Commit: ffe11b2bf2 (parent: 4cc8b60d1b)
@@ -3,6 +3,7 @@ from collections.abc import Sequence
 from dataclasses import dataclass
 from functools import cached_property
 from typing import Any, Optional, Union
+from typing_extensions import deprecated
 
 import torch
 from torch._ops import OpOverload
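The typing_extensions.deprecated decorator imported here backs the backward-compatibility alias added in the next hunk: invoking the wrapped function (or, below, reading the wrapped property) emits the configured warning category at runtime and also flags the symbol for static checkers. A minimal, self-contained sketch of that behavior, with illustrative names that are not part of this PR:

import warnings

from typing_extensions import deprecated


def new_name() -> int:
    return 42


@deprecated("old_name is deprecated, use new_name instead.", category=FutureWarning)
def old_name() -> int:
    # Delegates to the new spelling, mirroring the alias pattern below.
    return new_name()


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    assert old_name() == 42  # still works, but emits the warning
assert any(issubclass(w.category, FutureWarning) for w in caught)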
@@ -174,18 +175,32 @@ class TupleStrategy(StrategyType):
     then we should return a single OpStrategy instead of a TupleStrategy
     """
 
-    def __init__(self, childs: Sequence[StrategyType]) -> None:
+    def __init__(
+        self,
+        children: Sequence[StrategyType],
+    ) -> None:
         super().__init__()
-        self.childs: Sequence[StrategyType] = childs
+        self.children: Sequence[StrategyType] = children
+
+    @property
+    @deprecated(
+        "TupleStrategy.childs is deprecated, use TupleStrategy.children instead.",  # codespell:ignore childs
+        category=FutureWarning,
+    )
+    def childs(self) -> Sequence[StrategyType]:  # codespell:ignore childs
+        """
+        Alias for children, to maintain backward compatibility.
+        """
+        return self.children
 
     def child_mesh(self, index: int) -> DeviceMesh:
-        op_strategy = self.childs[index]
+        op_strategy = self.children[index]
         assert isinstance(op_strategy, OpStrategy)
         return op_strategy.mesh
 
     def __str__(self) -> str:
         child_strategies_str = ", ".join(
-            [f"{str(strat)}" for idx, strat in enumerate(self.childs)]
+            [f"{str(strat)}" for idx, strat in enumerate(self.children)]
         )
         return f"TupleStrategy({child_strategies_str})"
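After this rename, .children is the canonical attribute and .childs survives only as a deprecated alias that warns with FutureWarning on access. Code that has to run against both older DTensor builds (which expose only .childs) and builds that include this commit could use a small shim like the one below; the helper name is hypothetical and not part of the PR:

from collections.abc import Sequence
from typing import Any


def tuple_strategy_children(tuple_strategy: Any) -> Sequence[Any]:
    """Read a TupleStrategy's per-element strategies across PyTorch versions."""
    if hasattr(tuple_strategy, "children"):  # builds that include this commit
        return tuple_strategy.children
    # Older builds expose only the misspelled attribute.
    return tuple_strategy.childs  # codespell:ignore childs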
@@ -282,7 +297,7 @@ class OpSchema:
             args_schema.append(_pretty_print_spec(arg.strategies[0].output_specs))
             mesh_shape = arg.mesh_shape
         elif isinstance(arg, TupleStrategy):
-            first_op_strategy = arg.childs[0]
+            first_op_strategy = arg.children[0]
             assert isinstance(first_op_strategy, OpStrategy)
             mesh_shape = first_op_strategy.mesh_shape
             args_schema.append(str(arg))
@@ -342,7 +357,7 @@ class OpSchema:
             mesh = first_arg.mesh
         elif isinstance(first_arg, (list, tuple, TupleStrategy)):
             first_elem = (
-                first_arg.childs[0]
+                first_arg.children[0]
                 if isinstance(first_arg, TupleStrategy)
                 else first_arg[0]
             )
@@ -419,8 +419,8 @@ def foreach_norm_strategy(op_schema: OpSchema) -> TupleStrategy:
     assert isinstance(input_tuple_strategy, TupleStrategy)
     norm_type = args_schema[1] if len(args_schema) > 1 else 2
     assert isinstance(norm_type, (int, float, str)), f"{norm_type}"
-    output_tuple_strategy_childs: list[OpStrategy] = []
-    for op_strategy in input_tuple_strategy.childs:
+    output_tuple_strategy_children: list[OpStrategy] = []
+    for op_strategy in input_tuple_strategy.children:
         assert isinstance(op_strategy, OpStrategy), f"{op_strategy}"
         reduce_dims = list(range(op_strategy.ndim))
         output_strategy = common_reduction_strategy(
@@ -429,8 +429,8 @@ def foreach_norm_strategy(op_schema: OpSchema) -> TupleStrategy:
             reduction_linear=True,
             reduction_op=NormReduction(norm_type),
         )
-        output_tuple_strategy_childs.append(output_strategy)
-    return TupleStrategy(output_tuple_strategy_childs)
+        output_tuple_strategy_children.append(output_strategy)
+    return TupleStrategy(output_tuple_strategy_children)
 
 
 @register_op_strategy(
@@ -707,12 +707,12 @@ def list_pointwise_strategy(
     ) -> list[Optional[TupleStrategy]]:
         first_arg = args_schema[0]
         assert isinstance(first_arg, TupleStrategy)
-        strategy_len = len(first_arg.childs)
+        strategy_len = len(first_arg.children)
         tuple_strategies: list[Optional[TupleStrategy]] = []
         for arg_idx, arg in enumerate(args_schema):
             if isinstance(arg, TupleStrategy):
                 # every tuple strategy should have the same length
-                assert len(arg.childs) == strategy_len
+                assert len(arg.children) == strategy_len
                 tuple_strategies.append(arg)
             elif isinstance(arg, OpStrategy):
                 if arg_idx > 0:  # implicitly broadcast
@@ -732,10 +732,10 @@ def list_pointwise_strategy(
     follow_strategy: TupleStrategy = not_none(args_strategies[0])
     list_strategy: list[OpStrategy] = []
 
-    for child_idx, child_strtgy in enumerate(follow_strategy.childs):
+    for child_idx, child_strtgy in enumerate(follow_strategy.children):
         assert isinstance(child_strtgy, OpStrategy)
         args_schema: list[Optional[OpStrategy]] = [
-            cast(OpStrategy, arg_strategy.childs[child_idx]) if arg_strategy else None
+            cast(OpStrategy, arg_strategy.children[child_idx]) if arg_strategy else None
             for arg_strategy in args_strategies
         ]
         pointwise_strategy: OpStrategy = common_pointwise_strategy(
@@ -602,7 +602,7 @@ def _derive_follow_placements_from_tuple_strategy(
 
     follow_placements: Optional[list[Placement]] = None
     mesh = tuple_strategy.child_mesh(0)
-    for arg_strategy in tuple_strategy.childs:
+    for arg_strategy in tuple_strategy.children:
         assert isinstance(arg_strategy, OpStrategy)
         if arg_strategy.mesh != mesh:
             raise ValueError(
@@ -644,7 +644,7 @@ def stack_strategy(op_schema: OpSchema) -> StrategyType:
     args_schema = op_schema.args_schema
     input_tuple_strategy = args_schema[0]
     assert isinstance(input_tuple_strategy, TupleStrategy), f"{input_tuple_strategy}"
-    first_input_strategy = input_tuple_strategy.childs[0]
+    first_input_strategy = input_tuple_strategy.children[0]
     assert isinstance(first_input_strategy, OpStrategy), f"{first_input_strategy}"
     common_input_ndim = first_input_strategy.ndim
     dim = cast(int, args_schema[1]) if len(args_schema) > 1 else 0
@@ -662,7 +662,7 @@ def stack_strategy(op_schema: OpSchema) -> StrategyType:
 
     input_specs = tuple(
         DTensorSpec(mesh, tuple(follow_placements))
-        for _ in range(len(input_tuple_strategy.childs))
+        for _ in range(len(input_tuple_strategy.children))
     )
 
     follow_placements = normalize_shard_for_stack(follow_placements, dim)
@@ -681,7 +681,7 @@ def cat_strategy(op_schema: OpSchema) -> StrategyType:
     args_schema = op_schema.args_schema
     input_tuple_strategy = args_schema[0]
    assert isinstance(input_tuple_strategy, TupleStrategy), f"{input_tuple_strategy}"
-    first_input_strategy = input_tuple_strategy.childs[0]
+    first_input_strategy = input_tuple_strategy.children[0]
     assert isinstance(first_input_strategy, OpStrategy), f"{first_input_strategy}"
     common_input_ndim = first_input_strategy.ndim
     dim = cast(int, args_schema[1]) if len(args_schema) > 1 else 0
@@ -701,7 +701,7 @@ def cat_strategy(op_schema: OpSchema) -> StrategyType:
 
     input_specs = tuple(
         DTensorSpec(mesh, tuple(follow_placements))
-        for _ in range(len(input_tuple_strategy.childs))
+        for _ in range(len(input_tuple_strategy.children))
     )
     op_strategy.strategies.append(
         OpSpec(
@@ -765,7 +765,7 @@ def prop_index_put(op_schema: OpSchema) -> StrategyType:
     # 1. `indices` should all be replicated first.
     indices_redistribute_costs = []
     new_indices_spec: list[Optional[DTensorSpec]] = []
-    for indices_spec_child in indices_spec.childs:
+    for indices_spec_child in indices_spec.children:
         assert isinstance(indices_spec_child, OpStrategy)
 
         replicated_spec = DTensorSpec(
@@ -368,7 +368,7 @@ class ShardingPropagator:
         # runtime select OpSpec for each TupleStrategy input arg
         selected_strategies: list[OpSpec] = []
         out_spec_list: list[DTensorSpec] = []
-        for strategy in op_strategy.childs:
+        for strategy in op_strategy.children:
             assert isinstance(strategy, OpStrategy)
             selected_strategy = self._select_strategy(strategy)
             selected_strategies.append(selected_strategy)
@@ -77,7 +77,7 @@ def register_sharding(op: Union[OpOverload, list[OpOverload]]):
             # take the output spec from the first strategy
             return strategy.strategies[0].output_spec
         elif isinstance(strategy, TupleStrategy):
-            return tuple(strategy_to_spec(s) for s in strategy.childs)
+            return tuple(strategy_to_spec(s) for s in strategy.children)
         else:
             return strategy