diff --git a/.ci/pytorch/macos-test.sh b/.ci/pytorch/macos-test.sh index 64ea8a1c255..79d47da4317 100755 --- a/.ci/pytorch/macos-test.sh +++ b/.ci/pytorch/macos-test.sh @@ -13,6 +13,8 @@ if [[ ! $(python -c "import torch; print(int(torch.backends.openmp.is_available( fi popd +python -mpip install -r requirements.txt + # enable debug asserts in serialization export TORCH_SERIALIZATION_DEBUG=1 diff --git a/test/distributed/tensor/test_fake.py b/test/distributed/tensor/test_fake.py new file mode 100644 index 00000000000..099c6e87f5f --- /dev/null +++ b/test/distributed/tensor/test_fake.py @@ -0,0 +1,41 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates +# Owner(s): ["oncall: distributed"] + +import torch +from torch._subclasses.fake_tensor import FakeTensorMode +from torch.distributed.tensor import DTensor +from torch.distributed.tensor.placement_types import Shard +from torch.testing._internal.common_utils import run_tests, TestCase +from torch.testing._internal.distributed.fake_pg import FakeStore + + +class TestFakeDTensor(TestCase): + def test_fake_dtensor_operations(self): + # Use FakeTensorMode to handle CUDA tensors without actual CUDA + fake_mode = FakeTensorMode() + world_size = 4 + + fake_store = FakeStore() + torch.distributed.init_process_group( + "fake", store=fake_store, rank=0, world_size=world_size + ) + device_mesh = torch.distributed.device_mesh.init_device_mesh( + "cuda", + (2, world_size // 2), + ) + + # Create fake CUDA tensor using FakeTensorMode + with fake_mode: + x = torch.randn(1, 1, device="cuda") + x = DTensor.from_local(x, device_mesh, [Shard(0), Shard(1)]) + + # Test basic DTensor operations + self.assertIsInstance(x, DTensor) + + # Test sum operation + r = x.sum(1) + self.assertIsInstance(r, DTensor) + + +if __name__ == "__main__": + run_tests() diff --git a/test/test_numa_binding.py b/test/test_numa_binding.py index 764156ff9b9..d38032ba226 100644 --- a/test/test_numa_binding.py +++ b/test/test_numa_binding.py @@ -7,7 +7,7 @@ import sys from dataclasses import dataclass from multiprocessing.context import SpawnProcess from typing import Any, Optional -from unittest import skipUnless +from unittest import skipIf, skipUnless from unittest.mock import mock_open, patch import torch @@ -22,7 +22,7 @@ from torch.numa.binding import ( AffinityMode, NumaOptions, ) -from torch.testing._internal.common_utils import run_tests, TestCase +from torch.testing._internal.common_utils import IS_MACOS, run_tests, TestCase @dataclass(frozen=True) @@ -680,6 +680,7 @@ class NumaBindingTest(TestCase): set(range(0, 2)), ) + @skipIf(IS_MACOS, "sched_getaffinity doesn't exist") def test_binds_to_node_0_if_node_stored_as_minus_one(self) -> None: self._add_mock_hardware( num_sockets=1, diff --git a/torch/_C/_distributed_c10d.pyi b/torch/_C/_distributed_c10d.pyi index ad3d8e3abf2..79e437063b8 100644 --- a/torch/_C/_distributed_c10d.pyi +++ b/torch/_C/_distributed_c10d.pyi @@ -851,3 +851,12 @@ class ProcessGroupXCCL(Backend): def _set_process_group(pg: ProcessGroup) -> None: ... def _current_process_group() -> ProcessGroup: ... +def _dump_nccl_trace_json( + includeCollectives: Optional[bool] = ..., + onlyActive: Optional[bool] = ..., +) -> bytes: ... +def _dump_nccl_trace( + includeCollectives: Optional[bool] = ..., + includeStackTraces: Optional[bool] = ..., + onlyActive: Optional[bool] = ..., +) -> bytes: ... 
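
Note (not part of the diff): a hedged, illustrative sketch of how the flight-recorder dumpers typed above are reached once the centralized wrapper added later in this PR is in place. The helper name and the JSON decoding are assumptions for illustration only; on non-NCCL builds the stub in torch.distributed._C_stubs returns b"{}".

import json

from torch.distributed._distributed_c10d import _dump_nccl_trace_json


def flight_recorder_summary(only_active: bool = False) -> dict:
    # Bytes of JSON from the real extension, or b"{}" from the stub module.
    raw = _dump_nccl_trace_json(includeCollectives=True, onlyActive=only_active)
    return json.loads(raw)
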
diff --git a/torch/distributed/_C_stubs.py b/torch/distributed/_C_stubs.py new file mode 100644 index 00000000000..b241006372b --- /dev/null +++ b/torch/distributed/_C_stubs.py @@ -0,0 +1,150 @@ +# mypy: allow-untyped-defs +""" +Python stubs for backend-specific distributed components. + +Since _C._distributed_c10d always exists now, this module only provides +stubs for backend-specific functionality that may not be available in all builds +(e.g., NCCL, UCC, MPI, Gloo, etc.). +""" + +from __future__ import annotations + +from typing import Optional, TYPE_CHECKING + +from torch._C._distributed_c10d import Store + + +if TYPE_CHECKING: + from datetime import timedelta + +import torch + + +# Store classes +class HashStore(Store): + """Stub HashStore for builds without this functionality.""" + + def __init__(self, *args, **kwargs): + self._data = {} + + def set(self, key: str, value: str): + self._data[key] = value + + def get(self, key: str) -> bytes: + return self._data.get(key, "").encode() + + +# Backend-specific process group stubs +class ProcessGroupMPI: + """Stub ProcessGroupMPI for non-MPI builds.""" + + def __init__(self, *args, **kwargs): + pass + + +class ProcessGroupNCCL: + """Stub ProcessGroupNCCL for non-NCCL builds.""" + + def __init__(self, *args, **kwargs): + pass + + +class ProcessGroupGloo: + """Stub ProcessGroupGloo for non-Gloo builds.""" + + def __init__(self, *args, **kwargs): + pass + + +class ProcessGroupUCC: + """Stub ProcessGroupUCC for non-UCC builds.""" + + def __init__(self, *args, **kwargs): + pass + + +class ProcessGroupXCCL: + """Stub ProcessGroupXCCL for non-XCCL builds.""" + + def __init__(self, *args, **kwargs): + pass + + +class _ProcessGroupWrapper: + """Stub _ProcessGroupWrapper for non-Gloo builds.""" + + def __init__(self, process_group, *args, **kwargs): + self._process_group = process_group + + def __getattr__(self, name): + return getattr(self._process_group, name) + + +# NCCL-specific function stubs +_DEFAULT_PG_NCCL_TIMEOUT: Optional[timedelta] = None + + +def _hash_tensors(tensors): + """Stub function to hash tensors - returns dummy hash.""" + return 0 + + +def _dump_nccl_trace_json( + includeCollectives: Optional[bool] = None, onlyActive: Optional[bool] = None +) -> bytes: + """Stub function that returns empty JSON trace.""" + return b"{}" + + +def _dump_nccl_trace( + includeCollectives: Optional[bool] = None, + includeStackTraces: Optional[bool] = None, + onlyActive: Optional[bool] = None, +) -> bytes: + """Stub function that returns empty pickle trace.""" + return b"" + + +# NVSHMEM/SymmetricMemory stubs +def _is_nvshmem_available() -> bool: + """Stub function that returns False indicating NVSHMEM is not available.""" + return False + + +def _nvshmemx_cumodule_init(module: int) -> None: + """Stub function for NVSHMEM CU module initialization.""" + + +class _SymmetricMemory: + """Stub _SymmetricMemory class for builds without this functionality.""" + + def __init__(self, *args, **kwargs): + pass + + @classmethod + def empty_strided_p2p(cls, size, stride, dtype, device, group_name=None): + """Stub that returns a regular tensor.""" + return torch.empty(size, dtype=dtype, device=device) + + @classmethod + def rendezvous(cls, tensor, group_name=None): + """Stub that returns None.""" + return None + + @classmethod + def set_group_info(cls, *args, **kwargs): + """Stub that does nothing.""" + + @classmethod + def set_backend(cls, name): + """Stub that does nothing.""" + + @classmethod + def get_backend(cls, device): + """Stub that returns None.""" + 
return None + + @classmethod + def has_multicast_support(cls, device_type, device_index): + """Stub that returns False.""" + return False diff --git a/torch/distributed/__init__.py b/torch/distributed/__init__.py index bfb4175d61e..836b00c51c3 100644 --- a/torch/distributed/__init__.py +++ b/torch/distributed/__init__.py @@ -30,132 +30,124 @@ DistNetworkError = torch._C._DistNetworkError DistStoreError = torch._C._DistStoreError QueueEmptyError = torch._C._DistQueueEmptyError -if is_available(): - from torch._C._distributed_c10d import ( - _broadcast_coalesced, - _compute_bucket_assignment_by_size, - _ControlCollectives, - _DEFAULT_FIRST_BUCKET_BYTES, - _make_nccl_premul_sum, - _register_builtin_comm_hook, - _register_comm_hook, - _StoreCollectives, - _test_python_store, - _verify_params_across_processes, - Backend as _Backend, - BuiltinCommHookType, - DebugLevel, - FileStore, - get_debug_level, - GradBucket, - Logger, - PrefixStore, - ProcessGroup as ProcessGroup, - Reducer, - set_debug_level, - set_debug_level_from_env, - Store, - TCPStore, - Work as _Work, - ) +from torch.distributed._distributed_c10d import ( + _broadcast_coalesced, + _compute_bucket_assignment_by_size, + _ControlCollectives, + _DEFAULT_FIRST_BUCKET_BYTES, + _make_nccl_premul_sum, + _register_builtin_comm_hook, + _register_comm_hook, + _StoreCollectives, + _test_python_store, + _verify_params_across_processes, + Backend as _Backend, + BuiltinCommHookType, + DebugLevel, + FileStore, + get_debug_level, + GradBucket, + Logger, + PrefixStore, + ProcessGroup as ProcessGroup, + Reducer, + set_debug_level, + set_debug_level_from_env, + Store, + TCPStore, + Work as _Work, +) - class _DistributedPdb(pdb.Pdb): - """ - Supports using PDB from inside a multiprocessing child process. - Usage: - _DistributedPdb().set_trace() - """ +class _DistributedPdb(pdb.Pdb): + """ + Supports using PDB from inside a multiprocessing child process. - def interaction(self, *args, **kwargs): - _stdin = sys.stdin - try: - sys.stdin = open("/dev/stdin") - pdb.Pdb.interaction(self, *args, **kwargs) - finally: - sys.stdin = _stdin + Usage: + _DistributedPdb().set_trace() + """ - _breakpoint_cache: dict[int, typing.Any] = {} - - def breakpoint(rank: int = 0, skip: int = 0, timeout_s=3600): - """ - Set a breakpoint, but only on a single rank. All other ranks will wait for you to be - done with the breakpoint before continuing. - - Args: - rank (int): Which rank to break on. Default: ``0`` - skip (int): Skip the first ``skip`` calls to this breakpoint. Default: ``0``. - """ - if skip > 0: - key = hash(str(traceback.format_exc())) - counter = _breakpoint_cache.get(key, 0) + 1 - _breakpoint_cache[key] = counter - if counter <= skip: - log.warning("Skip the breakpoint, counter=%d", counter) - return - - # avoid having the default timeout (if short) interrupt your debug session - if timeout_s is not None: - for group in torch.distributed.distributed_c10d._pg_map: - torch.distributed.distributed_c10d._set_pg_timeout( - timedelta(seconds=timeout_s), group - ) - - if get_rank() == rank: - pdb = _DistributedPdb() - pdb.message( - "\n!!! ATTENTION !!!\n\n" - f"Type 'up' to get to the frame that called dist.breakpoint(rank={rank})\n" - ) - pdb.set_trace() - # If Meta/Python keys are in the TLS, we want to make sure that we ignore them - # and hit the (default) CPU/CUDA implementation of barrier. 
- meta_in_tls = torch._C._meta_in_tls_dispatch_include() - guard = torch._C._DisableTorchDispatch() # type: ignore[attr-defined] - torch._C._set_meta_in_tls_dispatch_include(False) + def interaction(self, *args, **kwargs): + _stdin = sys.stdin try: - barrier() + sys.stdin = open("/dev/stdin") + pdb.Pdb.interaction(self, *args, **kwargs) finally: - torch._C._set_meta_in_tls_dispatch_include(meta_in_tls) - del guard + sys.stdin = _stdin - if sys.platform != "win32": - from torch._C._distributed_c10d import HashStore - from .device_mesh import DeviceMesh, init_device_mesh +_breakpoint_cache: dict[int, typing.Any] = {} - # Variables prefixed with underscore are not auto imported - # See the comment in `distributed_c10d.py` above `_backend` on why we expose - # this. - from .distributed_c10d import * # noqa: F403 - from .distributed_c10d import ( - _all_gather_base, - _coalescing_manager, - _CoalescingManager, - _create_process_group_wrapper, - _get_process_group_name, - _rank_not_in_group, - _reduce_scatter_base, - _time_estimator, - get_node_local_rank, - ) - from .remote_device import _remote_device - from .rendezvous import ( - _create_store_from_options, - register_rendezvous_handler, - rendezvous, - ) - set_debug_level_from_env() +def breakpoint(rank: int = 0, skip: int = 0, timeout_s=3600): + """ + Set a breakpoint, but only on a single rank. All other ranks will wait for you to be + done with the breakpoint before continuing. -else: - # This stub is sufficient to get - # python test/test_public_bindings.py -k test_correct_module_names - # working even when USE_DISTRIBUTED=0. Feel free to add more - # stubs as necessary. - # We cannot define stubs directly because they confuse pyre + Args: + rank (int): Which rank to break on. Default: ``0`` + skip (int): Skip the first ``skip`` calls to this breakpoint. Default: ``0``. + """ + if skip > 0: + key = hash(str(traceback.format_exc())) + counter = _breakpoint_cache.get(key, 0) + 1 + _breakpoint_cache[key] = counter + if counter <= skip: + log.warning("Skip the breakpoint, counter=%d", counter) + return - class _ProcessGroupStub: - pass + # avoid having the default timeout (if short) interrupt your debug session + if timeout_s is not None: + for group in torch.distributed.distributed_c10d._pg_map: + torch.distributed.distributed_c10d._set_pg_timeout( + timedelta(seconds=timeout_s), group + ) - sys.modules["torch.distributed"].ProcessGroup = _ProcessGroupStub # type: ignore[attr-defined] + if get_rank() == rank: + pdb = _DistributedPdb() + pdb.message( + "\n!!! ATTENTION !!!\n\n" + f"Type 'up' to get to the frame that called dist.breakpoint(rank={rank})\n" + ) + pdb.set_trace() + # If Meta/Python keys are in the TLS, we want to make sure that we ignore them + # and hit the (default) CPU/CUDA implementation of barrier. + meta_in_tls = torch._C._meta_in_tls_dispatch_include() + guard = torch._C._DisableTorchDispatch() # type: ignore[attr-defined] + torch._C._set_meta_in_tls_dispatch_include(False) + try: + barrier() + finally: + torch._C._set_meta_in_tls_dispatch_include(meta_in_tls) + del guard + + +if sys.platform != "win32": + from torch.distributed._distributed_c10d import HashStore + +from .device_mesh import DeviceMesh, init_device_mesh + +# Variables prefixed with underscore are not auto imported +# See the comment in `distributed_c10d.py` above `_backend` on why we expose +# this. 
+from .distributed_c10d import * # noqa: F403 +from .distributed_c10d import ( + _all_gather_base, + _coalescing_manager, + _CoalescingManager, + _create_process_group_wrapper, + _get_process_group_name, + _rank_not_in_group, + _reduce_scatter_base, + _time_estimator, + get_node_local_rank, +) +from .remote_device import _remote_device +from .rendezvous import ( + _create_store_from_options, + register_rendezvous_handler, + rendezvous, +) + + +set_debug_level_from_env() diff --git a/torch/distributed/_dist2.py b/torch/distributed/_dist2.py index ce5cb8d7e0c..1c27bf55d68 100644 --- a/torch/distributed/_dist2.py +++ b/torch/distributed/_dist2.py @@ -10,7 +10,7 @@ from datetime import timedelta from typing import Protocol, Union import torch -from torch._C._distributed_c10d import ( +from torch.distributed._distributed_c10d import ( _current_process_group, _set_process_group, ProcessGroup, diff --git a/torch/distributed/_distributed_c10d.py b/torch/distributed/_distributed_c10d.py new file mode 100644 index 00000000000..f67ab1f999c --- /dev/null +++ b/torch/distributed/_distributed_c10d.py @@ -0,0 +1,238 @@ +# mypy: disable-error-code="assignment" +# noqa: F401 +""" +Centralized module for importing and re-exporting torch._C._distributed_c10d components. + +IMPORTANT PATTERN: +Never access torch._C._distributed_c10d directly in code. Always import from and use +torch.distributed._distributed_c10d which is guaranteed to have all functions available. + +Example: + # WRONG: torch._C._distributed_c10d._set_global_rank(rank) + # RIGHT: + from torch.distributed._distributed_c10d import _set_global_rank + _set_global_rank(rank) +""" + +from typing import TYPE_CHECKING + +# Import all core distributed components from the C extension +# NB: This list has to be spelled out because the _C module doesn't have __all__ +from torch._C._distributed_c10d import ( + _allow_inflight_collective_as_graph_input, + _broadcast_coalesced, + _compute_bucket_assignment_by_size, + _ControlCollectives, + _current_process_group, + _DEFAULT_FIRST_BUCKET_BYTES, + _DEFAULT_PG_TIMEOUT, + _DistributedBackendOptions, + _make_nccl_premul_sum, + _register_builtin_comm_hook, + _register_comm_hook, + _register_process_group, + _register_work, + _resolve_process_group, + _set_allow_inflight_collective_as_graph_input, + _set_global_rank, + _set_process_group, + _StoreCollectives, + _test_python_store, + _unregister_all_process_groups, + _unregister_process_group, + _verify_params_across_processes, + _WorkerServer, + AllgatherOptions, + AllreduceCoalescedOptions, + AllreduceOptions, + AllToAllOptions, + Backend, + BarrierOptions, + BroadcastOptions, + BuiltinCommHookType, + DebugLevel, + FakeProcessGroup, + FakeWork, + FileStore, + GatherOptions, + get_debug_level, + GradBucket, + Logger, + PrefixStore, + ProcessGroup, + ReduceOp, + ReduceOptions, + Reducer, + ReduceScatterOptions, + ScatterOptions, + set_debug_level, + set_debug_level_from_env, + Store, + TCPStore, + Work, +) + + +# Backend-specific components that may not be available +_MPI_AVAILABLE = False +_NCCL_AVAILABLE = False +_GLOO_AVAILABLE = False +_UCC_AVAILABLE = False +_XCCL_AVAILABLE = False + +# HashStore +try: + from torch._C._distributed_c10d import HashStore +except ImportError: + if not TYPE_CHECKING: + from torch.distributed._C_stubs import HashStore + +# NVSHMEM/SymmetricMemory components +try: + from torch._C._distributed_c10d import ( + _is_nvshmem_available, + _nvshmemx_cumodule_init, + _SymmetricMemory, + ) +except ImportError: + if not TYPE_CHECKING: + from 
torch.distributed._C_stubs import ( + _is_nvshmem_available, + _nvshmemx_cumodule_init, + _SymmetricMemory, + ) + +# MPI backend +try: + from torch._C._distributed_c10d import ProcessGroupMPI + + _MPI_AVAILABLE = True +except ImportError: + if not TYPE_CHECKING: + from torch.distributed._C_stubs import ProcessGroupMPI + +# NCCL backend +try: + from torch._C._distributed_c10d import ( + _DEFAULT_PG_NCCL_TIMEOUT, + _dump_nccl_trace, + _dump_nccl_trace_json, + _hash_tensors, + ProcessGroupNCCL, + ) + + _NCCL_AVAILABLE = True +except ImportError: + if not TYPE_CHECKING: + from torch.distributed._C_stubs import ( + _DEFAULT_PG_NCCL_TIMEOUT, + _dump_nccl_trace, + _dump_nccl_trace_json, + _hash_tensors, + ProcessGroupNCCL, + ) + +# Gloo backend +try: + from torch._C._distributed_c10d import _ProcessGroupWrapper, ProcessGroupGloo + + _GLOO_AVAILABLE = True +except ImportError: + if not TYPE_CHECKING: + from torch.distributed._C_stubs import _ProcessGroupWrapper, ProcessGroupGloo + +# UCC backend +try: + from torch._C._distributed_c10d import ProcessGroupUCC + + _UCC_AVAILABLE = True +except ImportError: + if not TYPE_CHECKING: + from torch.distributed._C_stubs import ProcessGroupUCC + +# XCCL backend +try: + from torch._C._distributed_c10d import ProcessGroupXCCL + + _XCCL_AVAILABLE = True +except ImportError: + if not TYPE_CHECKING: + from torch.distributed._C_stubs import ProcessGroupXCCL + +# Provide backwards compatibility by making all symbols available at module level +__all__ = [ + # Basic components + "_broadcast_coalesced", + "_compute_bucket_assignment_by_size", + "_ControlCollectives", + "_DEFAULT_FIRST_BUCKET_BYTES", + "_DEFAULT_PG_TIMEOUT", + "_DEFAULT_PG_NCCL_TIMEOUT", + "_make_nccl_premul_sum", + "_register_builtin_comm_hook", + "_register_comm_hook", + "_StoreCollectives", + "_test_python_store", + "_verify_params_across_processes", + "_allow_inflight_collective_as_graph_input", + "_register_work", + "_set_allow_inflight_collective_as_graph_input", + "_is_nvshmem_available", + "_nvshmemx_cumodule_init", + "_SymmetricMemory", + "_hash_tensors", + "_set_global_rank", + "_dump_nccl_trace", + "_dump_nccl_trace_json", + "Backend", + "BuiltinCommHookType", + "DebugLevel", + "FakeProcessGroup", + "FileStore", + "get_debug_level", + "GradBucket", + "HashStore", + "Logger", + "PrefixStore", + "ProcessGroup", + "Reducer", + "ReduceOp", + "set_debug_level", + "set_debug_level_from_env", + "Store", + "TCPStore", + "Work", + "FakeWork", + # Additional distributed_c10d components + "_DistributedBackendOptions", + "_register_process_group", + "_resolve_process_group", + "_unregister_all_process_groups", + "_unregister_process_group", + "_current_process_group", + "_set_process_group", + "_WorkerServer", + "AllgatherOptions", + "AllreduceCoalescedOptions", + "AllreduceOptions", + "AllToAllOptions", + "BarrierOptions", + "BroadcastOptions", + "GatherOptions", + "ReduceOptions", + "ReduceScatterOptions", + "ScatterOptions", + # Process group implementations + "ProcessGroupMPI", + "ProcessGroupNCCL", + "ProcessGroupGloo", + "ProcessGroupUCC", + "ProcessGroupXCCL", + "_ProcessGroupWrapper", + # Availability flags + "_MPI_AVAILABLE", + "_NCCL_AVAILABLE", + "_GLOO_AVAILABLE", + "_UCC_AVAILABLE", + "_XCCL_AVAILABLE", +] diff --git a/torch/distributed/_functional_collectives.py b/torch/distributed/_functional_collectives.py index 0b53da3988b..eb6a431f69a 100644 --- a/torch/distributed/_functional_collectives.py +++ b/torch/distributed/_functional_collectives.py @@ -7,6 +7,10 @@ from typing import Any, 
cast, Optional, TYPE_CHECKING, Union import torch import torch.distributed as dist import torch.distributed.distributed_c10d as c10d +from torch.distributed._distributed_c10d import ( + _allow_inflight_collective_as_graph_input, + _set_allow_inflight_collective_as_graph_input, +) from torch.distributed.device_mesh import DeviceMesh from torch.fx.experimental.proxy_tensor import get_proxy_mode @@ -853,15 +857,13 @@ def allow_inflight_collective_as_graph_input_ctx(value: bool = True): will be registered in the work registry, and the wait_tensor() in compiled region called on the output tensor of the collective will wait on the correct work object. """ - previous = torch._C._distributed_c10d._allow_inflight_collective_as_graph_input() + previous = _allow_inflight_collective_as_graph_input() try: - torch._C._distributed_c10d._set_allow_inflight_collective_as_graph_input(value) + _set_allow_inflight_collective_as_graph_input(value) yield finally: - torch._C._distributed_c10d._set_allow_inflight_collective_as_graph_input( - previous - ) + _set_allow_inflight_collective_as_graph_input(previous) def _make_all_gather_out_tensor(input, group_size): diff --git a/torch/distributed/_shard/sharded_tensor/reshard.py b/torch/distributed/_shard/sharded_tensor/reshard.py index daef9c35861..2bc3d65e5c8 100644 --- a/torch/distributed/_shard/sharded_tensor/reshard.py +++ b/torch/distributed/_shard/sharded_tensor/reshard.py @@ -4,7 +4,7 @@ import copy import torch import torch.distributed as dist import torch.distributed._shard.sharding_spec as shard_spec -from torch._C._distributed_c10d import ProcessGroup +from torch.distributed._distributed_c10d import ProcessGroup from torch.distributed._shard.metadata import ShardMetadata from torch.distributed._shard.sharding_spec._internals import ( get_chunked_dim_size, diff --git a/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/embedding_bag.py b/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/embedding_bag.py index 61808d0adf6..f02563619d2 100644 --- a/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/embedding_bag.py +++ b/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/embedding_bag.py @@ -4,7 +4,7 @@ from typing import cast import torch import torch.distributed as dist -from torch._C._distributed_c10d import ReduceOp +from torch.distributed._distributed_c10d import ReduceOp from torch.distributed._shard.sharded_tensor import ShardedTensor from torch.distributed._shard.sharding_spec import ChunkShardingSpec from torch.distributed._shard.sharding_spec.api import custom_sharding_spec_op diff --git a/torch/distributed/_symmetric_memory/__init__.py b/torch/distributed/_symmetric_memory/__init__.py index 43c2959fdd8..8154cd98091 100644 --- a/torch/distributed/_symmetric_memory/__init__.py +++ b/torch/distributed/_symmetric_memory/__init__.py @@ -15,7 +15,12 @@ import torch import torch.distributed._functional_collectives as funcol import torch.distributed.distributed_c10d as c10d from torch._C._autograd import DeviceType -from torch._C._distributed_c10d import _SymmetricMemory, Work as _Work +from torch.distributed._distributed_c10d import ( + _register_work, + _SymmetricMemory, + ProcessGroup, + Work as _Work, +) _group_name_to_store: dict[str, c10d.Store] = {} @@ -1488,7 +1493,7 @@ def _low_contention_all_gather( src_buf = symm_mem.get_buffer(remote_rank, tensor.shape, tensor.dtype) chunks[remote_rank].copy_(src_buf) symm_mem.barrier() - torch._C._distributed_c10d._register_work(output, Work()) + 
_register_work(output, Work()) return output @@ -1536,7 +1541,7 @@ def _low_contention_reduce_scatter_with_symm_mem_input( ret = ret.mean(dim=0) else: raise ValueError(f"reduce_op ({reduce_op}) is not supported") - torch._C._distributed_c10d._register_work(ret, Work()) + _register_work(ret, Work()) return ret @@ -1571,7 +1576,7 @@ def _low_contention_reduce_scatter_with_workspace( ret = ret.mean(dim=0) else: raise ValueError(f"reduce_op ({reduce_op}) is not supported") - torch._C._distributed_c10d._register_work(ret, Work()) + _register_work(ret, Work()) return ret @@ -1649,7 +1654,6 @@ from typing import overload, TYPE_CHECKING, Union if TYPE_CHECKING: - from torch._C._distributed_c10d import ProcessGroup from torch.types import _device, _dtype, _int @@ -1727,8 +1731,6 @@ def rendezvous( group (Union[str, :class:`torch.distributed.ProcessGroup`]): The group identifying the participating processes. This can be either a group name or a process group object. """ - from torch._C._distributed_c10d import ProcessGroup - if isinstance(group, str): group_name = group elif isinstance(group, ProcessGroup): @@ -1746,11 +1748,7 @@ def is_nvshmem_available() -> bool: Check if NVSHMEM is available in current build and on current system. """ - try: - from torch._C._distributed_c10d import _is_nvshmem_available - except ImportError: - # Not all builds have NVSHMEM support. - return False + from torch.distributed._distributed_c10d import _is_nvshmem_available # Check if NVSHMEM is available on current system. return _is_nvshmem_available() diff --git a/torch/distributed/_symmetric_memory/_nvshmem_triton.py b/torch/distributed/_symmetric_memory/_nvshmem_triton.py index c543fdffc1c..7b7828227d7 100644 --- a/torch/distributed/_symmetric_memory/_nvshmem_triton.py +++ b/torch/distributed/_symmetric_memory/_nvshmem_triton.py @@ -75,7 +75,7 @@ def enable_triton(lib_dir: Optional[str] = None) -> dict[str, str]: """ import triton - from torch._C._distributed_c10d import _nvshmemx_cumodule_init + from torch.distributed._distributed_c10d import _nvshmemx_cumodule_init if lib_dir is not None: lib_path = os.path.join(lib_dir, "libnvshmem_device.bc") diff --git a/torch/distributed/_tools/fake_collectives.py b/torch/distributed/_tools/fake_collectives.py index 3b201b39533..b89970ab334 100644 --- a/torch/distributed/_tools/fake_collectives.py +++ b/torch/distributed/_tools/fake_collectives.py @@ -2,7 +2,9 @@ import random from typing import Any import torch -from torch._C._distributed_c10d import ( + +# Import centralized distributed components +from torch.distributed._distributed_c10d import ( _resolve_process_group, FakeWork, ProcessGroup, diff --git a/torch/distributed/constants.py b/torch/distributed/constants.py index c1e604bc867..bfa87852186 100644 --- a/torch/distributed/constants.py +++ b/torch/distributed/constants.py @@ -1,7 +1,11 @@ from datetime import timedelta from typing import Optional -from torch._C._distributed_c10d import _DEFAULT_PG_TIMEOUT +# Import from centralized fallback module - no ImportError handling needed +from torch.distributed._distributed_c10d import ( + _DEFAULT_PG_NCCL_TIMEOUT, + _DEFAULT_PG_TIMEOUT, +) __all__ = ["default_pg_timeout", "default_pg_nccl_timeout"] @@ -16,11 +20,4 @@ default_pg_timeout: timedelta = _DEFAULT_PG_TIMEOUT # Later, we could consider merging them back together at the c++ layer if we can align on a same value. # (only if TORCH_NCCL_BLOCKING_WAIT or TORCH_NCCL_ASYNC_ERROR_HANDLING is set to 1). 
-try: - from torch._C._distributed_c10d import _DEFAULT_PG_NCCL_TIMEOUT - - default_pg_nccl_timeout: Optional[timedelta] = _DEFAULT_PG_NCCL_TIMEOUT -except ImportError: - # if C++ NCCL support is not compiled, we don't have access to the default nccl value. - # if anyone is actually trying to use nccl in this state, it should error. - default_pg_nccl_timeout = None +default_pg_nccl_timeout: Optional[timedelta] = _DEFAULT_PG_NCCL_TIMEOUT diff --git a/torch/distributed/device_mesh.py b/torch/distributed/device_mesh.py index c36ce0318fb..799d04ca51c 100644 --- a/torch/distributed/device_mesh.py +++ b/torch/distributed/device_mesh.py @@ -11,35 +11,14 @@ from itertools import chain, zip_longest from typing import Optional, TYPE_CHECKING, Union import torch -from torch.distributed import is_available from torch.utils._typing_utils import not_none __all__ = ["init_device_mesh", "DeviceMesh"] -if not is_available(): - import sys - - # We need to create the stubs when distributed is not available. - # Otherwise, we would fail the doc tests (```./.ci/pytorch/docs-test.sh```), - # since it would try to import ``torch.distributed.device_mesh`` or - # ``torch.distributed.init_device_mesh`` but cannot find them. - - class _DeviceMeshStub: - pass - - def _init_device_mesh_stub(): - pass - - sys.modules["torch.distributed.device_mesh"].DeviceMesh = _DeviceMeshStub # type: ignore[attr-defined] - sys.modules[ - "torch.distributed.device_mesh" - ].init_device_mesh = _init_device_mesh_stub # type: ignore[attr-defined] - - -else: - from torch._C._distributed_c10d import Backend as C10dBackend +if True: # just to temporarily avoid reindentation + from torch.distributed._distributed_c10d import Backend as C10dBackend from torch.distributed.distributed_c10d import ( _get_default_group, _resolve_process_group, @@ -526,15 +505,16 @@ else: # heuristic to set the current cuda/cuda-like device base on num of gpu devices available in each host # NOTE: This device selection would only work for homogeneous hardware. num_devices_per_host = device_handle.device_count() - if ( - world_size > num_devices_per_host - and world_size % num_devices_per_host != 0 - ): - raise RuntimeError( - f"DeviceMesh only support homogeneous hardware, but found " - f"{world_size} ranks and {num_devices_per_host} {self.device_type} devices!" - ) - device_handle.set_device(get_rank() % num_devices_per_host) + if num_devices_per_host: + if ( + world_size > num_devices_per_host + and world_size % num_devices_per_host != 0 + ): + raise RuntimeError( + f"DeviceMesh only support homogeneous hardware, but found " + f"{world_size} ranks and {num_devices_per_host} {self.device_type} devices!" 
+ ) + device_handle.set_device(get_rank() % num_devices_per_host) return _get_default_group() diff --git a/torch/distributed/distributed_c10d.py b/torch/distributed/distributed_c10d.py index 951cb2619b4..92eaaff3a51 100644 --- a/torch/distributed/distributed_c10d.py +++ b/torch/distributed/distributed_c10d.py @@ -19,13 +19,21 @@ from typing import Any, Callable, Optional, TYPE_CHECKING, Union from typing_extensions import deprecated import torch +import torch.distributed._distributed_c10d as _c10d from torch._C import _DistStoreError as DistStoreError -from torch._C._distributed_c10d import ( +from torch._utils_internal import set_pytorch_distributed_envs_from_justknobs +from torch.distributed._distributed_c10d import ( # Process group implementations; Availability flags _DistributedBackendOptions, + _GLOO_AVAILABLE, + _MPI_AVAILABLE, + _NCCL_AVAILABLE, + _ProcessGroupWrapper, _register_process_group, _resolve_process_group, + _UCC_AVAILABLE, _unregister_all_process_groups, _unregister_process_group, + _XCCL_AVAILABLE, AllgatherOptions, AllreduceCoalescedOptions, AllreduceOptions, @@ -37,6 +45,11 @@ from torch._C._distributed_c10d import ( get_debug_level, PrefixStore, ProcessGroup, + ProcessGroupGloo, + ProcessGroupMPI, + ProcessGroupNCCL, + ProcessGroupUCC, + ProcessGroupXCCL, ReduceOp, ReduceOptions, ReduceScatterOptions, @@ -44,7 +57,6 @@ from torch._C._distributed_c10d import ( Store, Work, ) -from torch._utils_internal import set_pytorch_distributed_envs_from_justknobs from torch.monitor import _WaitCounter from torch.overrides import handle_torch_function, has_torch_function from torch.utils._typing_utils import not_none @@ -131,17 +143,11 @@ __all__ = [ "split_group", ] -_MPI_AVAILABLE = True -_NCCL_AVAILABLE = True -_GLOO_AVAILABLE = True -_UCC_AVAILABLE = True -_XCCL_AVAILABLE = True - _pickler = pickle.Pickler _unpickler = pickle.Unpickler -# Change __module__ of all imported types from torch._C._distributed_c10d that are public +# Change __module__ of all imported types from the distributed wrapper that are public def _export_c_types() -> None: _public_types_to_change_module = [ AllreduceCoalescedOptions, @@ -167,45 +173,26 @@ def _export_c_types() -> None: _export_c_types() -try: - from torch._C._distributed_c10d import ProcessGroupMPI - +# Add process groups to __all__ and set their module based on availability +if _MPI_AVAILABLE: ProcessGroupMPI.__module__ = "torch.distributed.distributed_c10d" __all__ += ["ProcessGroupMPI"] -except ImportError: - _MPI_AVAILABLE = False - -try: - from torch._C._distributed_c10d import ProcessGroupNCCL +if _NCCL_AVAILABLE: ProcessGroupNCCL.__module__ = "torch.distributed.distributed_c10d" __all__ += ["ProcessGroupNCCL"] -except ImportError: - _NCCL_AVAILABLE = False - -try: - from torch._C._distributed_c10d import _ProcessGroupWrapper, ProcessGroupGloo +if _GLOO_AVAILABLE: ProcessGroupGloo.__module__ = "torch.distributed.distributed_c10d" __all__ += ["ProcessGroupGloo"] -except ImportError: - _GLOO_AVAILABLE = False - -try: - from torch._C._distributed_c10d import ProcessGroupUCC +if _UCC_AVAILABLE: ProcessGroupUCC.__module__ = "torch.distributed.distributed_c10d" __all__ += ["ProcessGroupUCC"] -except ImportError: - _UCC_AVAILABLE = False - -try: - from torch._C._distributed_c10d import ProcessGroupXCCL +if _XCCL_AVAILABLE: ProcessGroupXCCL.__module__ = "torch.distributed.distributed_c10d" __all__ += ["ProcessGroupXCCL"] -except ImportError: - _XCCL_AVAILABLE = False logger = logging.getLogger(__name__) @@ -1327,7 +1314,8 @@ def 
_get_default_store() -> Store: def _update_default_pg(pg) -> None: _world.default_pg = pg rank = pg.rank() if pg is not None and pg != GroupMember.NON_GROUP_MEMBER else -1 - torch._C._distributed_c10d._set_global_rank(rank) + + _c10d._set_global_rank(rank) def get_backend_config(group: Optional[ProcessGroup] = None) -> str: @@ -1964,7 +1952,7 @@ def _new_process_group_helper( if device_id: pg.bound_device_id = device_id - backend_class: torch._C._distributed_c10d.Backend + backend_class: _c10d.Backend for device, backend_str in backend_config.get_device_backend_map().items(): # Use the group name as prefix in the default store, such that # a single store can be reused by multiple groups. @@ -3079,7 +3067,9 @@ def _object_to_tensor(obj, device, group): if get_debug_level() == DebugLevel.DETAIL and is_nccl_available(): backend = get_backend(group) if backend == Backend.NCCL: - hash = torch._C._distributed_c10d._hash_tensors([byte_tensor]) + from torch.distributed._distributed_c10d import _hash_tensors + + hash = _hash_tensors([byte_tensor]) logger.warning( "_object_to_tensor size: %s hash value: %s", byte_tensor.numel(), @@ -3094,7 +3084,9 @@ def _tensor_to_object(tensor, tensor_size, group): if get_debug_level() == DebugLevel.DETAIL and is_nccl_available(): backend = get_backend(group) if backend == Backend.NCCL: - hash = torch._C._distributed_c10d._hash_tensors([tensor]) + from torch.distributed._distributed_c10d import _hash_tensors + + hash = _hash_tensors([tensor]) logger.warning( "_tensor_to_object size: %s hash value: %s", tensor.numel(), hash ) @@ -4971,7 +4963,7 @@ def monitored_barrier( def _create_process_group_wrapper( - wrapped_pg: torch._C._distributed_c10d.Backend, + wrapped_pg: _c10d.Backend, store_prefix: str, store: Store, rank: int, diff --git a/torch/distributed/elastic/control_plane.py b/torch/distributed/elastic/control_plane.py index 817255edd23..63334a0ca3f 100644 --- a/torch/distributed/elastic/control_plane.py +++ b/torch/distributed/elastic/control_plane.py @@ -14,7 +14,7 @@ TORCH_WORKER_SERVER_SOCKET = "TORCH_WORKER_SERVER_SOCKET" @contextmanager def _worker_server(socket_path: str) -> Generator[None, None, None]: - from torch._C._distributed_c10d import _WorkerServer + from torch.distributed._distributed_c10d import _WorkerServer server = _WorkerServer(socket_path) try: diff --git a/torch/distributed/rpc/__init__.py b/torch/distributed/rpc/__init__.py index adf901d6b6e..27a945a92e4 100644 --- a/torch/distributed/rpc/__init__.py +++ b/torch/distributed/rpc/__init__.py @@ -37,7 +37,6 @@ if is_available(): import numbers import torch.distributed.autograd as dist_autograd - from torch._C._distributed_c10d import Store from torch._C._distributed_rpc import ( # noqa: F401 _cleanup_python_rpc_handler, _DEFAULT_INIT_METHOD, @@ -70,6 +69,7 @@ if is_available(): RpcBackendOptions, WorkerInfo, ) + from torch.distributed._distributed_c10d import Store if _is_tensorpipe_available: from torch._C._distributed_rpc import ( # noqa: F401 diff --git a/torch/distributed/tensor/_collective_utils.py b/torch/distributed/tensor/_collective_utils.py index 4fce6fea538..f01836c5959 100644 --- a/torch/distributed/tensor/_collective_utils.py +++ b/torch/distributed/tensor/_collective_utils.py @@ -8,8 +8,10 @@ from typing import Optional import torch import torch.distributed._functional_collectives as funcol import torch.distributed.tensor._dtensor_spec as dtensor_spec -from torch._C._distributed_c10d import _resolve_process_group from torch._logging import warning_once + +# Import from 
centralized fallback module - no conditional imports needed +from torch.distributed._distributed_c10d import _resolve_process_group from torch.distributed.device_mesh import _mesh_resources, DeviceMesh from torch.distributed.distributed_c10d import ( _get_group_size_by_name, diff --git a/torch/testing/_internal/distributed/fake_pg.py b/torch/testing/_internal/distributed/fake_pg.py index e160f2fe506..a36d2da29b4 100644 --- a/torch/testing/_internal/distributed/fake_pg.py +++ b/torch/testing/_internal/distributed/fake_pg.py @@ -1,7 +1,7 @@ # mypy: allow-untyped-defs import torch.distributed as dist -from torch._C._distributed_c10d import FakeProcessGroup +from torch.distributed._distributed_c10d import FakeProcessGroup class FakeStore(dist.Store):
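
Note (not part of the diff): a hedged sketch of the import convention this PR establishes, namely importing from torch.distributed._distributed_c10d and branching on the availability flags instead of wrapping torch._C._distributed_c10d imports in try/except at every call site. The helper function below is hypothetical.

from torch.distributed._distributed_c10d import (
    _GLOO_AVAILABLE,
    _NCCL_AVAILABLE,
    ProcessGroupNCCL,
)


def preferred_backend() -> str:
    # ProcessGroupNCCL resolves to the real backend when compiled in, or to the
    # no-op stub from torch.distributed._C_stubs otherwise; the flag says which.
    if _NCCL_AVAILABLE:
        return f"nccl ({ProcessGroupNCCL.__name__})"
    return "gloo" if _GLOO_AVAILABLE else "unavailable"
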