[BE] fix typo in torch/distributed/tensor/: childs -> children (#156609)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156609
Approved by: https://github.com/wanchaol, https://github.com/cyyever
ghstack dependencies: #156311
This commit is contained in:
Xuehai Pan 2025-07-09 13:23:54 +08:00 committed by PyTorch MergeBot
parent 4cc8b60d1b
commit ffe11b2bf2
6 changed files with 37 additions and 22 deletions

View File

@ -3,6 +3,7 @@ from collections.abc import Sequence
from dataclasses import dataclass from dataclasses import dataclass
from functools import cached_property from functools import cached_property
from typing import Any, Optional, Union from typing import Any, Optional, Union
from typing_extensions import deprecated
import torch import torch
from torch._ops import OpOverload from torch._ops import OpOverload
@ -174,18 +175,32 @@ class TupleStrategy(StrategyType):
then we should return a single OpStrategy instead of a TupleStrategy then we should return a single OpStrategy instead of a TupleStrategy
""" """
def __init__(self, childs: Sequence[StrategyType]) -> None: def __init__(
self,
children: Sequence[StrategyType],
) -> None:
super().__init__() super().__init__()
self.childs: Sequence[StrategyType] = childs self.children: Sequence[StrategyType] = children
@property
@deprecated(
"TupleStrategy.childs is deprecated, use TupleStrategy.children instead.", # codespell:ignore childs
category=FutureWarning,
)
def childs(self) -> Sequence[StrategyType]: # codespell:ignore childs
"""
Alias for children, to maintain backward compatibility.
"""
return self.children
def child_mesh(self, index: int) -> DeviceMesh: def child_mesh(self, index: int) -> DeviceMesh:
op_strategy = self.childs[index] op_strategy = self.children[index]
assert isinstance(op_strategy, OpStrategy) assert isinstance(op_strategy, OpStrategy)
return op_strategy.mesh return op_strategy.mesh
def __str__(self) -> str: def __str__(self) -> str:
child_strategies_str = ", ".join( child_strategies_str = ", ".join(
[f"{str(strat)}" for idx, strat in enumerate(self.childs)] [f"{str(strat)}" for idx, strat in enumerate(self.children)]
) )
return f"TupleStrategy({child_strategies_str})" return f"TupleStrategy({child_strategies_str})"
@ -282,7 +297,7 @@ class OpSchema:
args_schema.append(_pretty_print_spec(arg.strategies[0].output_specs)) args_schema.append(_pretty_print_spec(arg.strategies[0].output_specs))
mesh_shape = arg.mesh_shape mesh_shape = arg.mesh_shape
elif isinstance(arg, TupleStrategy): elif isinstance(arg, TupleStrategy):
first_op_strategy = arg.childs[0] first_op_strategy = arg.children[0]
assert isinstance(first_op_strategy, OpStrategy) assert isinstance(first_op_strategy, OpStrategy)
mesh_shape = first_op_strategy.mesh_shape mesh_shape = first_op_strategy.mesh_shape
args_schema.append(str(arg)) args_schema.append(str(arg))
@ -342,7 +357,7 @@ class OpSchema:
mesh = first_arg.mesh mesh = first_arg.mesh
elif isinstance(first_arg, (list, tuple, TupleStrategy)): elif isinstance(first_arg, (list, tuple, TupleStrategy)):
first_elem = ( first_elem = (
first_arg.childs[0] first_arg.children[0]
if isinstance(first_arg, TupleStrategy) if isinstance(first_arg, TupleStrategy)
else first_arg[0] else first_arg[0]
) )

View File

@ -419,8 +419,8 @@ def foreach_norm_strategy(op_schema: OpSchema) -> TupleStrategy:
assert isinstance(input_tuple_strategy, TupleStrategy) assert isinstance(input_tuple_strategy, TupleStrategy)
norm_type = args_schema[1] if len(args_schema) > 1 else 2 norm_type = args_schema[1] if len(args_schema) > 1 else 2
assert isinstance(norm_type, (int, float, str)), f"{norm_type}" assert isinstance(norm_type, (int, float, str)), f"{norm_type}"
output_tuple_strategy_childs: list[OpStrategy] = [] output_tuple_strategy_children: list[OpStrategy] = []
for op_strategy in input_tuple_strategy.childs: for op_strategy in input_tuple_strategy.children:
assert isinstance(op_strategy, OpStrategy), f"{op_strategy}" assert isinstance(op_strategy, OpStrategy), f"{op_strategy}"
reduce_dims = list(range(op_strategy.ndim)) reduce_dims = list(range(op_strategy.ndim))
output_strategy = common_reduction_strategy( output_strategy = common_reduction_strategy(
@ -429,8 +429,8 @@ def foreach_norm_strategy(op_schema: OpSchema) -> TupleStrategy:
reduction_linear=True, reduction_linear=True,
reduction_op=NormReduction(norm_type), reduction_op=NormReduction(norm_type),
) )
output_tuple_strategy_childs.append(output_strategy) output_tuple_strategy_children.append(output_strategy)
return TupleStrategy(output_tuple_strategy_childs) return TupleStrategy(output_tuple_strategy_children)
@register_op_strategy( @register_op_strategy(

View File

@ -707,12 +707,12 @@ def list_pointwise_strategy(
) -> list[Optional[TupleStrategy]]: ) -> list[Optional[TupleStrategy]]:
first_arg = args_schema[0] first_arg = args_schema[0]
assert isinstance(first_arg, TupleStrategy) assert isinstance(first_arg, TupleStrategy)
strategy_len = len(first_arg.childs) strategy_len = len(first_arg.children)
tuple_strategies: list[Optional[TupleStrategy]] = [] tuple_strategies: list[Optional[TupleStrategy]] = []
for arg_idx, arg in enumerate(args_schema): for arg_idx, arg in enumerate(args_schema):
if isinstance(arg, TupleStrategy): if isinstance(arg, TupleStrategy):
# every tuple strategy should have the same length # every tuple strategy should have the same length
assert len(arg.childs) == strategy_len assert len(arg.children) == strategy_len
tuple_strategies.append(arg) tuple_strategies.append(arg)
elif isinstance(arg, OpStrategy): elif isinstance(arg, OpStrategy):
if arg_idx > 0: # implicitly broadcast if arg_idx > 0: # implicitly broadcast
@ -732,10 +732,10 @@ def list_pointwise_strategy(
follow_strategy: TupleStrategy = not_none(args_strategies[0]) follow_strategy: TupleStrategy = not_none(args_strategies[0])
list_strategy: list[OpStrategy] = [] list_strategy: list[OpStrategy] = []
for child_idx, child_strtgy in enumerate(follow_strategy.childs): for child_idx, child_strtgy in enumerate(follow_strategy.children):
assert isinstance(child_strtgy, OpStrategy) assert isinstance(child_strtgy, OpStrategy)
args_schema: list[Optional[OpStrategy]] = [ args_schema: list[Optional[OpStrategy]] = [
cast(OpStrategy, arg_strategy.childs[child_idx]) if arg_strategy else None cast(OpStrategy, arg_strategy.children[child_idx]) if arg_strategy else None
for arg_strategy in args_strategies for arg_strategy in args_strategies
] ]
pointwise_strategy: OpStrategy = common_pointwise_strategy( pointwise_strategy: OpStrategy = common_pointwise_strategy(

View File

@ -602,7 +602,7 @@ def _derive_follow_placements_from_tuple_strategy(
follow_placements: Optional[list[Placement]] = None follow_placements: Optional[list[Placement]] = None
mesh = tuple_strategy.child_mesh(0) mesh = tuple_strategy.child_mesh(0)
for arg_strategy in tuple_strategy.childs: for arg_strategy in tuple_strategy.children:
assert isinstance(arg_strategy, OpStrategy) assert isinstance(arg_strategy, OpStrategy)
if arg_strategy.mesh != mesh: if arg_strategy.mesh != mesh:
raise ValueError( raise ValueError(
@ -644,7 +644,7 @@ def stack_strategy(op_schema: OpSchema) -> StrategyType:
args_schema = op_schema.args_schema args_schema = op_schema.args_schema
input_tuple_strategy = args_schema[0] input_tuple_strategy = args_schema[0]
assert isinstance(input_tuple_strategy, TupleStrategy), f"{input_tuple_strategy}" assert isinstance(input_tuple_strategy, TupleStrategy), f"{input_tuple_strategy}"
first_input_strategy = input_tuple_strategy.childs[0] first_input_strategy = input_tuple_strategy.children[0]
assert isinstance(first_input_strategy, OpStrategy), f"{first_input_strategy}" assert isinstance(first_input_strategy, OpStrategy), f"{first_input_strategy}"
common_input_ndim = first_input_strategy.ndim common_input_ndim = first_input_strategy.ndim
dim = cast(int, args_schema[1]) if len(args_schema) > 1 else 0 dim = cast(int, args_schema[1]) if len(args_schema) > 1 else 0
@ -662,7 +662,7 @@ def stack_strategy(op_schema: OpSchema) -> StrategyType:
input_specs = tuple( input_specs = tuple(
DTensorSpec(mesh, tuple(follow_placements)) DTensorSpec(mesh, tuple(follow_placements))
for _ in range(len(input_tuple_strategy.childs)) for _ in range(len(input_tuple_strategy.children))
) )
follow_placements = normalize_shard_for_stack(follow_placements, dim) follow_placements = normalize_shard_for_stack(follow_placements, dim)
@ -681,7 +681,7 @@ def cat_strategy(op_schema: OpSchema) -> StrategyType:
args_schema = op_schema.args_schema args_schema = op_schema.args_schema
input_tuple_strategy = args_schema[0] input_tuple_strategy = args_schema[0]
assert isinstance(input_tuple_strategy, TupleStrategy), f"{input_tuple_strategy}" assert isinstance(input_tuple_strategy, TupleStrategy), f"{input_tuple_strategy}"
first_input_strategy = input_tuple_strategy.childs[0] first_input_strategy = input_tuple_strategy.children[0]
assert isinstance(first_input_strategy, OpStrategy), f"{first_input_strategy}" assert isinstance(first_input_strategy, OpStrategy), f"{first_input_strategy}"
common_input_ndim = first_input_strategy.ndim common_input_ndim = first_input_strategy.ndim
dim = cast(int, args_schema[1]) if len(args_schema) > 1 else 0 dim = cast(int, args_schema[1]) if len(args_schema) > 1 else 0
@ -701,7 +701,7 @@ def cat_strategy(op_schema: OpSchema) -> StrategyType:
input_specs = tuple( input_specs = tuple(
DTensorSpec(mesh, tuple(follow_placements)) DTensorSpec(mesh, tuple(follow_placements))
for _ in range(len(input_tuple_strategy.childs)) for _ in range(len(input_tuple_strategy.children))
) )
op_strategy.strategies.append( op_strategy.strategies.append(
OpSpec( OpSpec(
@ -765,7 +765,7 @@ def prop_index_put(op_schema: OpSchema) -> StrategyType:
# 1. `indices` should all be replicated first. # 1. `indices` should all be replicated first.
indices_redistribute_costs = [] indices_redistribute_costs = []
new_indices_spec: list[Optional[DTensorSpec]] = [] new_indices_spec: list[Optional[DTensorSpec]] = []
for indices_spec_child in indices_spec.childs: for indices_spec_child in indices_spec.children:
assert isinstance(indices_spec_child, OpStrategy) assert isinstance(indices_spec_child, OpStrategy)
replicated_spec = DTensorSpec( replicated_spec = DTensorSpec(

View File

@ -368,7 +368,7 @@ class ShardingPropagator:
# runtime select OpSpec for each TupleStrategy input arg # runtime select OpSpec for each TupleStrategy input arg
selected_strategies: list[OpSpec] = [] selected_strategies: list[OpSpec] = []
out_spec_list: list[DTensorSpec] = [] out_spec_list: list[DTensorSpec] = []
for strategy in op_strategy.childs: for strategy in op_strategy.children:
assert isinstance(strategy, OpStrategy) assert isinstance(strategy, OpStrategy)
selected_strategy = self._select_strategy(strategy) selected_strategy = self._select_strategy(strategy)
selected_strategies.append(selected_strategy) selected_strategies.append(selected_strategy)

View File

@ -77,7 +77,7 @@ def register_sharding(op: Union[OpOverload, list[OpOverload]]):
# take the output spec from the first strategy # take the output spec from the first strategy
return strategy.strategies[0].output_spec return strategy.strategies[0].output_spec
elif isinstance(strategy, TupleStrategy): elif isinstance(strategy, TupleStrategy):
return tuple(strategy_to_spec(s) for s in strategy.childs) return tuple(strategy_to_spec(s) for s in strategy.children)
else: else:
return strategy return strategy