diff --git a/torch/distributed/run.py b/torch/distributed/run.py
index ba8fd3dc69e..dbcae1a641a 100644
--- a/torch/distributed/run.py
+++ b/torch/distributed/run.py
@@ -770,14 +770,9 @@ def _get_logs_specs_class(logs_specs_name: Optional[str]) -> type[LogsSpecs]:
     logs_specs_cls = None
     if logs_specs_name is not None:
         eps = metadata.entry_points()
-        if hasattr(eps, "select"):  # >= 3.10
-            group = eps.select(group="torchrun.logs_specs")
-            if group.select(name=logs_specs_name):
-                logs_specs_cls = group[logs_specs_name].load()
-
-        elif specs := eps.get("torchrun.logs_specs"):  # < 3.10
-            if entrypoint_list := [ep for ep in specs if ep.name == logs_specs_name]:
-                logs_specs_cls = entrypoint_list[0].load()
+        group = eps.select(group="torchrun.logs_specs")
+        if group.select(name=logs_specs_name):
+            logs_specs_cls = group[logs_specs_name].load()
 
         if logs_specs_cls is None:
             raise ValueError(
diff --git a/torch/distributed/tensor/placement_types.py b/torch/distributed/tensor/placement_types.py
index 704efe3e859..ad304229a27 100644
--- a/torch/distributed/tensor/placement_types.py
+++ b/torch/distributed/tensor/placement_types.py
@@ -368,11 +368,7 @@ class Shard(Placement):
         return f"S({self.dim})"
 
 
-# kw_only is only available in python >= 3.10
-kw_only_dataclass = dict(kw_only=True) if "kw_only" in dataclass.__kwdefaults__ else {}
-
-
-@dataclass(frozen=True, **kw_only_dataclass)
+@dataclass(frozen=True, kw_only=True)
 class _StridedShard(Shard):
     """
     _StridedShard is only introduced to support 2D FSDP2 + TP sharding where the tensor
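
Note: the surviving branch in `run.py` relies on `importlib.metadata.entry_points().select(...)`, which only exists on Python >= 3.10, so this change assumes that version floor; the deleted `eps.get(...)` branch was the pre-3.10 dict-style API. A minimal sketch of how resolution works end to end under that assumption. The plugin package and entry-point name (`my_pkg`, `my_logs`) are illustrative, not part of this PR:

```python
# Hypothetical plugin: a package registers an entry point under the
# "torchrun.logs_specs" group in its pyproject.toml, e.g.
#
#   [project.entry-points."torchrun.logs_specs"]
#   my_logs = "my_pkg.logging:MyLogsSpecs"
#
# Resolution then mirrors the post-change code path (Python >= 3.10 only):
from importlib import metadata


def load_logs_specs(name: str):
    eps = metadata.entry_points()
    # EntryPoints.select() filters by group; available since Python 3.10.
    group = eps.select(group="torchrun.logs_specs")
    if group.select(name=name):  # empty EntryPoints is falsy
        return group[name].load()  # imports my_pkg.logging, returns MyLogsSpecs
    raise ValueError(f"no 'torchrun.logs_specs' entry point named {name!r}")
```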
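Similarly, the removed `kw_only_dataclass` shim in `placement_types.py` existed only to feature-detect `kw_only` (added to `dataclasses.dataclass` in Python 3.10) via `dataclass.__kwdefaults__`; with the 3.10 floor it can be passed directly. A small standalone sketch of what `kw_only=True` buys; the class and fields below are illustrative stand-ins for `_StridedShard`, not the actual definition:

```python
from dataclasses import dataclass


# kw_only=True (Python >= 3.10) makes every generated __init__ parameter
# keyword-only, forcing call sites to name each field and sidestepping
# dataclass field-ordering constraints across inheritance.
@dataclass(frozen=True, kw_only=True)
class StridedSpec:
    dim: int
    split_factor: int = 1  # illustrative field, loosely modeled on _StridedShard


s = StridedSpec(dim=0, split_factor=2)  # OK: keyword construction
# StridedSpec(0, 2)  # TypeError: positional arguments are rejected
```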