[dtensor] move DTensor to public namespace (#133113)

Move DTensor into the public namespace and formally add a
documentation page that covers all of the public APIs. This includes:

* many path renames and import-path fixes
* a dedicated doc page without much content yet (more to be added in
  follow-up PRs)
* a shim module that redirects the old `torch.distributed._tensor` import
  paths to the new module, to preserve backward compatibility (BC) for
  existing users

The BC preservation is evidenced by the fact that all DTensor tests still
pass without changing their public imports, so it is safe to land these
changes.
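
As an illustrative sanity check (a sketch, not part of this PR's test suite), the shim makes the
old and new import paths resolve to the same objects:

# Illustrative BC check: the old (shimmed) and the new public import path
# should expose the same objects after the shim is in place.
import torch.distributed._tensor as legacy_dt   # old path, redirected by the shim
import torch.distributed.tensor as public_dt    # new public path

assert legacy_dt.DTensor is public_dt.DTensor
assert legacy_dt.distribute_tensor is public_dt.distribute_tensor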

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133113
Approved by: https://github.com/XilunWu
ghstack dependencies: #133305, #133306
Wanchao Liang 2024-08-16 18:39:24 -07:00 committed by PyTorch MergeBot
parent 1a4709cef5
commit 2ee6b97464
88 changed files with 433 additions and 265 deletions

View File

@ -626,6 +626,17 @@ coverage_ignore_functions = [
# torch.distributed.rpc.internal # torch.distributed.rpc.internal
"deserialize", "deserialize",
"serialize", "serialize",
# torch.distributed.tensor.api
"distribute_module",
"distribute_tensor",
# torch.distributed.tensor.random
"is_rng_supported_mesh",
# torch.distributed.tensor.experimental
"context_parallel",
"local_map",
"register_sharding",
# torch.distributed.tensor.debug
"visualize_sharding",
# torch.distributed.tensor.parallel.api # torch.distributed.tensor.parallel.api
"parallelize_module", "parallelize_module",
# torch.distributed.tensor.parallel.input_reshard # torch.distributed.tensor.parallel.input_reshard
@ -2621,6 +2632,15 @@ coverage_ignore_classes = [
"RemoteException", "RemoteException",
# torch.distributed.rpc.rref_proxy # torch.distributed.rpc.rref_proxy
"RRefProxy", "RRefProxy",
# torch.distributed.tensor.api
"DTensor",
# torch.distributed.tensor.placement_types
"DTensorSpec",
"Placement",
# torch.distributed.tensor.random
"OffsetBasedRNGTracker",
# torch.distributed.tensor.debug
"CommDebugMode",
# torch.distributed.tensor.parallel.fsdp # torch.distributed.tensor.parallel.fsdp
"DTensorExtensions", "DTensorExtensions",
# torch.distributed.tensor.parallel.style # torch.distributed.tensor.parallel.style

View File

@ -876,7 +876,6 @@ If you are running single node training, it may be convenient to interactively b
.. py:module:: torch.distributed.nn.api .. py:module:: torch.distributed.nn.api
.. py:module:: torch.distributed.nn.jit .. py:module:: torch.distributed.nn.jit
.. py:module:: torch.distributed.nn.jit.templates .. py:module:: torch.distributed.nn.jit.templates
.. py:module:: torch.distributed.tensor
.. py:module:: torch.distributed.algorithms.ddp_comm_hooks.ddp_zero_hook .. py:module:: torch.distributed.algorithms.ddp_comm_hooks.ddp_zero_hook
.. py:module:: torch.distributed.algorithms.ddp_comm_hooks.debugging_hooks .. py:module:: torch.distributed.algorithms.ddp_comm_hooks.debugging_hooks
.. py:module:: torch.distributed.algorithms.ddp_comm_hooks.default_hooks .. py:module:: torch.distributed.algorithms.ddp_comm_hooks.default_hooks

View File

@ -0,0 +1,104 @@
.. role:: hidden
    :class: hidden-section

PyTorch DTensor (Distributed Tensor)
======================================================
.. note::
    ``torch.distributed.tensor`` is currently in an alpha state and under
    development. We are committed to backward compatibility for most of the APIs
    listed in this doc, but there might be API changes if necessary.
PyTorch DTensor offers simple and flexible tensor sharding primitives that transparently handle distributed
logic, including sharded storage, operator computation and collective communications across devices/hosts.
``DTensor`` can be used to build different parallelism solutions and supports a sharded state_dict representation
when working with multi-dimensional sharding.
Please see examples from the PyTorch native parallelism solutions that are built on top of ``DTensor``:
* `Tensor Parallel <https://pytorch.org/docs/main/distributed.tensor.parallel.html>`__
* `FSDP2 <https://github.com/pytorch/torchtitan/blob/main/docs/fsdp.md>`__
.. automodule:: torch.distributed.tensor
.. currentmodule:: torch.distributed.tensor
:class:`DTensor` follows the SPMD (single program, multiple data) programming model to empower users to
write distributed programs as if they were single-device programs with the same convergence properties. It
provides a uniform tensor sharding layout (DTensor Layout) specified through the :class:`DeviceMesh`
and :class:`Placement`:

- :class:`DeviceMesh` represents the device topology and the communicators of the cluster using
  an n-dimensional array.
- :class:`Placement` describes the sharding layout of the logical tensor on the :class:`DeviceMesh`.
  DTensor supports three types of placements: :class:`Shard`, :class:`Replicate` and :class:`Partial`.
There are three ways to construct a :class:`DTensor` (see the sketch after this list):

* :meth:`distribute_tensor` creates a :class:`DTensor` from a logical or "global" ``torch.Tensor`` on
  each rank. This can be used to shard leaf ``torch.Tensor`` s (i.e. model parameters/buffers
  and inputs).
* :meth:`DTensor.from_local` creates a :class:`DTensor` from a local ``torch.Tensor`` on each rank, which can
  be used to create a :class:`DTensor` from non-leaf ``torch.Tensor`` s (i.e. intermediate activation
  tensors during forward/backward).
* DTensor provides dedicated tensor factory functions (e.g. :meth:`empty`, :meth:`ones`, :meth:`randn`, etc.)
  that allow creating a :class:`DTensor` by directly specifying the :class:`DeviceMesh` and
  :class:`Placement`.
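
A minimal sketch of the first two construction paths, assuming a 1-D mesh over four CUDA ranks
(the mesh size, device type and tensor shapes below are illustrative)::

    import torch
    from torch.distributed.device_mesh import init_device_mesh
    from torch.distributed.tensor import DTensor, Shard, distribute_tensor

    # Assumes torchrun (or similar) has already launched 4 ranks; values are illustrative.
    mesh = init_device_mesh("cuda", (4,))

    # 1) Shard a logical/global tensor along dim 0 across the mesh.
    global_weight = torch.randn(16, 8)
    sharded = distribute_tensor(global_weight, mesh, placements=[Shard(0)])

    # 2) Wrap the local shard that already lives on each rank into a DTensor.
    local_chunk = torch.randn(4, 8)
    wrapped = DTensor.from_local(local_chunk, mesh, placements=[Shard(0)])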
.. autoclass:: DTensor
    :members:
    :member-order: bysource
.. autofunction:: distribute_tensor
Along with :meth:`distribute_tensor`, DTensor also offers a :meth:`distribute_module` API to allow easier
sharding at the :class:`nn.Module` level:
.. autofunction:: distribute_module
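
A rough sketch of module-level sharding with ``distribute_module`` (the ``shard_params`` partition
function and the dim-0 sharding choice are illustrative assumptions, not API requirements)::

    import torch.nn as nn
    from torch.distributed.tensor import Shard, distribute_module, distribute_tensor

    def shard_params(mod_name, mod, mesh):
        # Hypothetical partition_fn: shard every parameter of the module along dim 0.
        for name, param in list(mod.named_parameters(recurse=False)):
            dist_param = nn.Parameter(distribute_tensor(param, mesh, [Shard(0)]))
            mod.register_parameter(name, dist_param)

    # `mesh` is the DeviceMesh from the construction sketch above.
    sharded_mod = distribute_module(nn.Linear(8, 8), mesh, partition_fn=shard_params)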
DTensor supports the following types of :class:`Placement` on each :class:`DeviceMesh` dimension:
.. autoclass:: Shard
    :members:
    :undoc-members:

.. autoclass:: Replicate
    :members:
    :undoc-members:

.. autoclass:: Partial
    :members:
    :undoc-members:
DTensor provides dedicated tensor factory functions that allow creating a :class:`DTensor` directly
using ``torch.Tensor``-like factory APIs (e.g. ``torch.ones``, ``torch.empty``, etc.), by additionally
specifying the :class:`DeviceMesh` and :class:`Placement` for the :class:`DTensor` being created:
.. autofunction:: zeros
.. autofunction:: ones
.. autofunction:: empty
.. autofunction:: full
.. autofunction:: rand
.. autofunction:: randn
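
For example, a minimal sketch of creating a replicated :class:`DTensor` directly from one of these
factory functions (the shape and the reuse of ``mesh`` from the earlier sketch are illustrative)::

    from torch.distributed.tensor import Replicate, ones

    # Create a DTensor of ones, replicated on every rank of `mesh` (from the sketch above).
    replicated_ones = ones(16, 8, device_mesh=mesh, placements=[Replicate()])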
.. modules that are missing docs, add the doc later when necessary
.. py:module:: torch.distributed.tensor.api
.. py:module:: torch.distributed.tensor.device_mesh
.. py:module:: torch.distributed.tensor.random
.. py:module:: torch.distributed.tensor.placement_types
.. py:module:: torch.distributed.tensor.experimental
.. py:module:: torch.distributed.tensor.experimental.attention
.. py:module:: torch.distributed.tensor.experimental.func_map
.. py:module:: torch.distributed.tensor.experimental.register_sharding
.. py:module:: torch.distributed.tensor.experimental.tp_transform
.. py:module:: torch.distributed.tensor.debug
.. py:module:: torch.distributed.tensor.debug.comm_mode
.. py:module:: torch.distributed.tensor.debug.visualize_sharding

View File

@ -74,12 +74,13 @@ Features described in this documentation are classified by release status:
torch.backends <backends> torch.backends <backends>
torch.export <export> torch.export <export>
torch.distributed <distributed> torch.distributed <distributed>
torch.distributed.tensor <distributed.tensor>
torch.distributed.algorithms.join <distributed.algorithms.join> torch.distributed.algorithms.join <distributed.algorithms.join>
torch.distributed.elastic <distributed.elastic> torch.distributed.elastic <distributed.elastic>
torch.distributed.fsdp <fsdp> torch.distributed.fsdp <fsdp>
torch.distributed.tensor.parallel <distributed.tensor.parallel>
torch.distributed.optim <distributed.optim> torch.distributed.optim <distributed.optim>
torch.distributed.pipelining <distributed.pipelining> torch.distributed.pipelining <distributed.pipelining>
torch.distributed.tensor.parallel <distributed.tensor.parallel>
torch.distributed.checkpoint <distributed.checkpoint> torch.distributed.checkpoint <distributed.checkpoint>
torch.distributions <distributions> torch.distributions <distributions>
torch.compiler <torch.compiler> torch.compiler <torch.compiler>

View File

@ -33,7 +33,8 @@
"torch.nn.quantizable": "torch.ao.nn.quantizable", "torch.nn.quantizable": "torch.ao.nn.quantizable",
"torch.nn.quantizable.modules": "torch.ao.nn.quantizable.modules", "torch.nn.quantizable.modules": "torch.ao.nn.quantizable.modules",
"torch.nn.quantizable.modules.activation": "torch.ao.nn.quantizable.modules.activation", "torch.nn.quantizable.modules.activation": "torch.ao.nn.quantizable.modules.activation",
"torch.nn.quantizable.modules.rnn": "torch.ao.nn.quantizable.modules.rnn" "torch.nn.quantizable.modules.rnn": "torch.ao.nn.quantizable.modules.rnn",
"torch.distributed.tensor.device_mesh": "torch.distributed.device_mesh"
}, },
"torch.backends": [ "torch.backends": [
"contextmanager" "contextmanager"
@ -231,6 +232,9 @@
"urlunparse" "urlunparse"
], ],
"torch.distributed.rpc": [], "torch.distributed.rpc": [],
"torch.distributed.tensor": [
"DeviceMesh"
],
"torch.fft": [ "torch.fft": [
"Tensor", "Tensor",
"fft", "fft",

View File

@ -3,9 +3,9 @@
import torch import torch
from torch.distributed._tensor import DeviceMesh from torch.distributed._tensor import DeviceMesh
from torch.distributed._tensor._op_schema import OpSchema
from torch.distributed._tensor.ops._common_rules import einop_rule, pointwise_rule
from torch.distributed._tensor.placement_types import DTensorSpec, TensorMeta from torch.distributed._tensor.placement_types import DTensorSpec, TensorMeta
from torch.distributed.tensor._op_schema import OpSchema
from torch.distributed.tensor._ops._common_rules import einop_rule, pointwise_rule
from torch.testing._internal.common_utils import run_tests from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import ( from torch.testing._internal.distributed._tensor.common_dtensor import (
DTensorTestBase, DTensorTestBase,

View File

@ -909,7 +909,7 @@ class TestDTensorPlacementTypes(DTensorTestBase):
] ]
assert_array_equal(expected_pad_sizes, pad_sizes) assert_array_equal(expected_pad_sizes, pad_sizes)
from torch.distributed._tensor._collective_utils import unpad_tensor from torch.distributed.tensor._collective_utils import unpad_tensor
unpadded_list = [ unpadded_list = [
unpad_tensor(tensor, shard_placement.dim, pad_sizes[i]) unpad_tensor(tensor, shard_placement.dim, pad_sizes[i])

View File

@ -167,7 +167,7 @@ class TestEmbeddingOp(DTensorTestBase):
self._run_embedding_op_test(mesh, 0, [6, 7, 6], 13, 22) self._run_embedding_op_test(mesh, 0, [6, 7, 6], 13, 22)
self._run_embedding_op_test(mesh, 0, [34], 15, 14, padding_idx=10) self._run_embedding_op_test(mesh, 0, [34], 15, 14, padding_idx=10)
from torch.distributed._tensor.ops._embedding_ops import _MaskPartial from torch.distributed.tensor._ops._embedding_ops import _MaskPartial
# test collectives # test collectives
embedding_mod = torch.nn.Embedding(10, 20, device=self.device_type) embedding_mod = torch.nn.Embedding(10, 20, device=self.device_type)
@ -191,7 +191,7 @@ class TestEmbeddingOp(DTensorTestBase):
inp = torch.randint(0, 10, (4, 4), device=self.device_type) inp = torch.randint(0, 10, (4, 4), device=self.device_type)
replicated_inp = DTensor.from_local(inp, mesh, [Replicate()], run_check=False) replicated_inp = DTensor.from_local(inp, mesh, [Replicate()], run_check=False)
from torch.distributed._tensor.ops._embedding_ops import _MaskPartial from torch.distributed.tensor._ops._embedding_ops import _MaskPartial
# case 1: two embeddings with the same shape, thus sharing the underying _MaskPartial # case 1: two embeddings with the same shape, thus sharing the underying _MaskPartial
# and MaskBuffer, because of cache hit from sharding propagation # and MaskBuffer, because of cache hit from sharding propagation

View File

@ -7,8 +7,8 @@ import itertools
import torch import torch
from torch.distributed._tensor import DeviceMesh, distribute_module, distribute_tensor from torch.distributed._tensor import DeviceMesh, distribute_module, distribute_tensor
from torch.distributed._tensor.debug import CommDebugMode from torch.distributed._tensor.debug import CommDebugMode
from torch.distributed._tensor.ops.utils import is_tensor_partial, normalize_dim
from torch.distributed._tensor.placement_types import Replicate, Shard from torch.distributed._tensor.placement_types import Replicate, Shard
from torch.distributed.tensor._ops.utils import is_tensor_partial, normalize_dim
from torch.testing._internal.common_utils import run_tests from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import ( from torch.testing._internal.distributed._tensor.common_dtensor import (
DTensorTestBase, DTensorTestBase,

View File

@ -4,12 +4,6 @@ from itertools import chain
import torch import torch
from torch.distributed._tensor import DeviceMesh, DTensor from torch.distributed._tensor import DeviceMesh, DTensor
from torch.distributed._tensor._collective_utils import redistribute_cost
from torch.distributed._tensor._op_schema import OpSchema, OpStrategy, PlacementStrategy
from torch.distributed._tensor.ops._einsum_strategy import (
EinsumDims,
gen_einsum_strategies,
)
from torch.distributed._tensor.placement_types import ( from torch.distributed._tensor.placement_types import (
DTensorSpec, DTensorSpec,
Partial, Partial,
@ -17,6 +11,12 @@ from torch.distributed._tensor.placement_types import (
Shard, Shard,
TensorMeta, TensorMeta,
) )
from torch.distributed.tensor._collective_utils import redistribute_cost
from torch.distributed.tensor._op_schema import OpSchema, OpStrategy, PlacementStrategy
from torch.distributed.tensor._ops._einsum_strategy import (
EinsumDims,
gen_einsum_strategies,
)
from torch.testing._internal.common_utils import run_tests, TestCase from torch.testing._internal.common_utils import run_tests, TestCase
from torch.testing._internal.distributed._tensor.common_dtensor import DTensorOpTestBase from torch.testing._internal.distributed._tensor.common_dtensor import DTensorOpTestBase
@ -169,7 +169,7 @@ class TestCostModel(DTensorOpTestBase):
def test_redistribute_cost_latency(self): def test_redistribute_cost_latency(self):
# test cost model on addmm op # test cost model on addmm op
from torch.distributed._tensor.ops._matrix_ops import addmm_strategy from torch.distributed.tensor._ops._matrix_ops import addmm_strategy
mesh = self.build_device_mesh() mesh = self.build_device_mesh()
shard0_placement = (Shard(0),) shard0_placement = (Shard(0),)
@ -246,7 +246,7 @@ class TestCostModel(DTensorOpTestBase):
self.assertTrue(allreduce_cost > reduce_scatter_cost) self.assertTrue(allreduce_cost > reduce_scatter_cost)
def test_mm_strategies(self): def test_mm_strategies(self):
from torch.distributed._tensor.ops._matrix_ops import mm_strategy from torch.distributed.tensor._ops._matrix_ops import mm_strategy
mesh = self.build_device_mesh() mesh = self.build_device_mesh()
lhs_tensor = torch.randn(6, 8) lhs_tensor = torch.randn(6, 8)
@ -292,7 +292,7 @@ class TestCostModel(DTensorOpTestBase):
self.assertFalse(output_sharding.needs_redistribute) self.assertFalse(output_sharding.needs_redistribute)
def test_bmm_strategies(self): def test_bmm_strategies(self):
from torch.distributed._tensor.ops._matrix_ops import bmm_strategy from torch.distributed.tensor._ops._matrix_ops import bmm_strategy
mesh = self.build_device_mesh() mesh = self.build_device_mesh()
lhs_tensor = torch.randn(8, 6, 8) lhs_tensor = torch.randn(8, 6, 8)

View File

@ -5,10 +5,10 @@ import itertools
import torch import torch
from torch.distributed._tensor import DeviceMesh, distribute_tensor, DTensor from torch.distributed._tensor import DeviceMesh, distribute_tensor, DTensor
from torch.distributed._tensor._collective_utils import shard_dim_alltoall
from torch.distributed._tensor.debug import CommDebugMode from torch.distributed._tensor.debug import CommDebugMode
from torch.distributed._tensor.placement_types import Partial, Replicate, Shard from torch.distributed._tensor.placement_types import Partial, Replicate, Shard
from torch.distributed.device_mesh import init_device_mesh from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor._collective_utils import shard_dim_alltoall
from torch.testing._internal.common_utils import run_tests from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import ( from torch.testing._internal.distributed._tensor.common_dtensor import (
DTensorTestBase, DTensorTestBase,
@ -207,7 +207,7 @@ class RedistributeTest(DTensorTestBase):
with self.assertRaisesRegex(RuntimeError, "Can not redistribute to Partial"): with self.assertRaisesRegex(RuntimeError, "Can not redistribute to Partial"):
partial_tensor = replica_tensor.redistribute(device_mesh, [partial_spec]) partial_tensor = replica_tensor.redistribute(device_mesh, [partial_spec])
from torch.distributed._tensor._redistribute import Redistribute from torch.distributed.tensor._redistribute import Redistribute
comm_mode = CommDebugMode() comm_mode = CommDebugMode()

View File

@ -445,7 +445,7 @@ class DistTensorOpsTest(DTensorTestBase):
# case 2 input sharding: input sharded, index replicated, output mask partial # case 2 input sharding: input sharded, index replicated, output mask partial
# only works when index has size 1 on the gather dimension and # only works when index has size 1 on the gather dimension and
# input is sharded on the gather dimension # input is sharded on the gather dimension
from torch.distributed._tensor.ops._embedding_ops import _MaskPartial from torch.distributed.tensor._ops._embedding_ops import _MaskPartial
gather_dim = 1 gather_dim = 1
global_input = torch.randn(12, 8, 16) global_input = torch.randn(12, 8, 16)

View File

@ -9,7 +9,8 @@ import torch.distributed as dist
from torch import rand, randn, Tensor from torch import rand, randn, Tensor
from torch.distributed._tensor import DeviceMesh, distribute_tensor, Replicate, Shard from torch.distributed._tensor import DeviceMesh, distribute_tensor, Replicate, Shard
from torch.distributed._tensor.debug import CommDebugMode from torch.distributed._tensor.debug import CommDebugMode
from torch.distributed._tensor.ops._view_ops import ( from torch.distributed._tensor.placement_types import Placement
from torch.distributed.tensor._ops._view_ops import (
Broadcast, Broadcast,
dim_maps, dim_maps,
Flatten, Flatten,
@ -19,7 +20,6 @@ from torch.distributed._tensor.ops._view_ops import (
Split, Split,
view_groups, view_groups,
) )
from torch.distributed._tensor.placement_types import Placement
from torch.testing._internal.common_utils import run_tests from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import ( from torch.testing._internal.distributed._tensor.common_dtensor import (
DTensorTestBase, DTensorTestBase,

View File

@ -5,11 +5,6 @@ import os
import torch import torch
import torch.distributed._functional_collectives as funcol import torch.distributed._functional_collectives as funcol
from torch.distributed._tensor import DTensor from torch.distributed._tensor import DTensor
from torch.distributed._tensor._collective_utils import (
mesh_broadcast,
mesh_scatter,
unpad_tensor,
)
from torch.distributed._tensor.placement_types import _Partial, Shard from torch.distributed._tensor.placement_types import _Partial, Shard
from torch.distributed.device_mesh import _mesh_resources, DeviceMesh, init_device_mesh from torch.distributed.device_mesh import _mesh_resources, DeviceMesh, init_device_mesh
from torch.distributed.distributed_c10d import ( from torch.distributed.distributed_c10d import (
@ -22,6 +17,11 @@ from torch.distributed.distributed_c10d import (
is_nccl_available, is_nccl_available,
ProcessGroup, ProcessGroup,
) )
from torch.distributed.tensor._collective_utils import (
mesh_broadcast,
mesh_scatter,
unpad_tensor,
)
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_utils import run_tests from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import ( from torch.testing._internal.distributed._tensor.common_dtensor import (

View File

@ -1432,12 +1432,12 @@ class GuardBuilder(GuardBuilderBase):
} }
) )
if torch.distributed.is_available(): if torch.distributed.is_available():
from torch.distributed._tensor.placement_types import ( from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor.placement_types import (
Partial, Partial,
Replicate, Replicate,
Shard, Shard,
) )
from torch.distributed.device_mesh import DeviceMesh
ok_types = ok_types + ( ok_types = ok_types + (
Shard, Shard,

View File

@ -142,7 +142,7 @@ manual_torch_name_rule_map = {
"torch.distributed.is_initialized": TorchInGraphFunctionVariable, "torch.distributed.is_initialized": TorchInGraphFunctionVariable,
"torch.distributed.get_rank": TorchInGraphFunctionVariable, "torch.distributed.get_rank": TorchInGraphFunctionVariable,
"torch.distributed.get_world_size": TorchInGraphFunctionVariable, "torch.distributed.get_world_size": TorchInGraphFunctionVariable,
"torch.distributed._tensor.api.DTensor#from_local": TorchInGraphFunctionVariable, "torch.distributed.tensor.api.DTensor#from_local": TorchInGraphFunctionVariable,
"torch.distributed.distributed_c10d._get_group_size_by_name": TorchInGraphFunctionVariable, "torch.distributed.distributed_c10d._get_group_size_by_name": TorchInGraphFunctionVariable,
"torch.distributed.distributed_c10d._resolve_group_name_by_ranks_and_tag": TorchInGraphFunctionVariable, "torch.distributed.distributed_c10d._resolve_group_name_by_ranks_and_tag": TorchInGraphFunctionVariable,
"torch.distributed.distributed_c10d._get_group_tag": TorchInGraphFunctionVariable, "torch.distributed.distributed_c10d._get_group_tag": TorchInGraphFunctionVariable,
@ -3190,8 +3190,8 @@ LEGACY_MOD_INLINELIST = {
if torch.distributed.is_available(): if torch.distributed.is_available():
LEGACY_MOD_INLINELIST |= { LEGACY_MOD_INLINELIST |= {
"torch.distributed._tensor.api", "torch.distributed.tensor.api",
"torch.distributed._tensor.device_mesh", "torch.distributed.tensor.device_mesh",
"torch.distributed.device_mesh", "torch.distributed.device_mesh",
"torch.distributed.algorithms._checkpoint.checkpoint_wrapper", "torch.distributed.algorithms._checkpoint.checkpoint_wrapper",
"torch.distributed.tensor.parallel._data_parallel_utils", "torch.distributed.tensor.parallel._data_parallel_utils",

View File

@ -50,7 +50,7 @@ class DistributedVariable(VariableTracker):
def is_from_local(value): def is_from_local(value):
if not DistributedVariable.is_available(): if not DistributedVariable.is_available():
return False return False
from torch.distributed._tensor import DTensor from torch.distributed.tensor import DTensor
return inspect.isfunction(value) and value is DTensor.from_local return inspect.isfunction(value) and value is DTensor.from_local
@ -108,7 +108,7 @@ class PlacementClassVariable(DistributedVariable):
if not DistributedVariable.is_available(): if not DistributedVariable.is_available():
return False return False
from torch.distributed._tensor.placement_types import Placement from torch.distributed.tensor.placement_types import Placement
return type(value) is type and issubclass(value, Placement) return type(value) is type and issubclass(value, Placement)
@ -143,7 +143,7 @@ class PlacementVariable(DistributedVariable):
if not DistributedVariable.is_available(): if not DistributedVariable.is_available():
return False return False
from torch.distributed._tensor.placement_types import Placement from torch.distributed.tensor.placement_types import Placement
return isinstance(value, Placement) return isinstance(value, Placement)

View File

@ -598,7 +598,6 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
) )
if DistributedVariable.is_available(): if DistributedVariable.is_available():
from torch.distributed._tensor import DTensor
from torch.distributed.distributed_c10d import ( from torch.distributed.distributed_c10d import (
_get_group_size_by_name, _get_group_size_by_name,
_get_group_tag, _get_group_tag,
@ -606,6 +605,7 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
_resolve_group_name_by_ranks_and_tag, _resolve_group_name_by_ranks_and_tag,
get_process_group_ranks, get_process_group_ranks,
) )
from torch.distributed.tensor import DTensor
@register( @register(
_get_group_size_by_name, _get_group_size_by_name,

View File

@ -4,8 +4,8 @@ from typing import cast, List, NamedTuple, Optional, Tuple, Union
import torch import torch
import torch._dynamo.compiled_autograd as ca import torch._dynamo.compiled_autograd as ca
import torch.distributed as dist import torch.distributed as dist
from torch.distributed._tensor import DTensor
from torch.distributed.distributed_c10d import ReduceOp from torch.distributed.distributed_c10d import ReduceOp
from torch.distributed.tensor import DTensor
from ._fsdp_common import ( from ._fsdp_common import (
_get_dim0_padded_size, _get_dim0_padded_size,

View File

@ -10,8 +10,8 @@ import torch._dynamo.compiled_autograd as ca
import torch.distributed as dist import torch.distributed as dist
import torch.nn as nn import torch.nn as nn
from torch.distributed._composable.contract import _get_registry from torch.distributed._composable.contract import _get_registry
from torch.distributed._tensor import DeviceMesh, DTensor from torch.distributed.tensor import DeviceMesh, DTensor
from torch.distributed._tensor.placement_types import DTensorSpec from torch.distributed.tensor.placement_types import DTensorSpec
@dataclass @dataclass

View File

@ -4,8 +4,8 @@ from typing import List, Optional, Set, Tuple, Union
import torch import torch
import torch.distributed as dist import torch.distributed as dist
import torch.nn as nn import torch.nn as nn
from torch.distributed._tensor import DeviceMesh, DTensor, init_device_mesh
from torch.distributed.device_mesh import _get_device_handle from torch.distributed.device_mesh import _get_device_handle
from torch.distributed.tensor import DeviceMesh, DTensor, init_device_mesh
from torch.utils._python_dispatch import is_traceable_wrapper_subclass from torch.utils._python_dispatch import is_traceable_wrapper_subclass
from ._fsdp_common import _is_composable_with_fsdp, FSDPMeshInfo, HSDPMeshInfo from ._fsdp_common import _is_composable_with_fsdp, FSDPMeshInfo, HSDPMeshInfo

View File

@ -9,9 +9,9 @@ import torch._dynamo.compiled_autograd as ca
import torch.nn as nn import torch.nn as nn
from torch._prims_common import make_contiguous_strides_for from torch._prims_common import make_contiguous_strides_for
from torch.distributed._functional_collectives import AsyncCollectiveTensor from torch.distributed._functional_collectives import AsyncCollectiveTensor
from torch.distributed._tensor import DTensor, Replicate, Shard from torch.distributed.tensor import DTensor, Replicate, Shard
from torch.distributed._tensor.device_mesh import _mesh_resources from torch.distributed.tensor.device_mesh import _mesh_resources
from torch.distributed._tensor.placement_types import ( from torch.distributed.tensor.placement_types import (
_StridedShard, _StridedShard,
DTensorSpec, DTensorSpec,
Placement, Placement,

View File

@ -6,7 +6,7 @@ from typing import Any, cast, Iterable, List, NoReturn, Optional, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
from torch.distributed._composable import contract from torch.distributed._composable import contract
from torch.distributed._tensor import DeviceMesh from torch.distributed.tensor import DeviceMesh
from torch.distributed.utils import _get_root_modules from torch.distributed.utils import _get_root_modules
from ._fsdp_api import MixedPrecisionPolicy, OffloadPolicy from ._fsdp_api import MixedPrecisionPolicy, OffloadPolicy

View File

@ -97,7 +97,7 @@ RANK_TYPES = Union[
List[List[int]], List[List[int]],
dist.ProcessGroup, dist.ProcessGroup,
DeviceMesh, DeviceMesh,
Tuple["dist._tensor.DeviceMesh", int], Tuple["dist.tensor.DeviceMesh", int],
str, str,
] ]

View File

@ -27,7 +27,7 @@ from torch.distributed._functional_collectives import AsyncCollectiveTensor
if dist.is_available() or TYPE_CHECKING: if dist.is_available() or TYPE_CHECKING:
from torch.distributed import distributed_c10d from torch.distributed import distributed_c10d
from torch.distributed._shard.sharded_tensor import ShardedTensor from torch.distributed._shard.sharded_tensor import ShardedTensor
from torch.distributed._tensor import distribute_tensor, DTensor, Replicate from torch.distributed.tensor import distribute_tensor, DTensor, Replicate
def _identity_func( def _identity_func(

View File

@ -1,58 +1,43 @@
(Removed: the contents of the old ``torch.distributed._tensor`` package ``__init__.py``, i.e. the
re-exports of ``DTensor``, ``distribute_tensor``, the placement types and the foreach registration;
that code now lives in the new ``torch/distributed/tensor/__init__.py`` shown in a later hunk.
Added: the following BC shim that redirects the old import path to the new module.)

"""
NOTICE: DTensor has moved to torch.distributed.tensor

This file is a shim to redirect to the new location, and
we keep the old import path starts with `_tensor` for
backward compatibility.
"""
import importlib
import sys

import torch.distributed.tensor


def _populate():  # type: ignore[no-untyped-def]
    for name in (
        # TODO: _utils here mainly for checkpoint imports BC, remove it
        "_utils",
        "api",
        "debug",
        "device_mesh",
        "experimental",
        "placement_types",
        "random",
    ):
        try:
            globals()[name] = sys.modules[
                f"torch.distributed._tensor.{name}"
            ] = importlib.import_module(f"torch.distributed.tensor.{name}")
        except ImportError as e:
            import traceback

            traceback.print_exc()
            raise ImportError(
                f"Failed to import torch.distributed.tensor.{name} due to {e}"
            ) from e

    for name, val in torch.distributed.tensor.__dict__.items():
        # Skip private names and tensor parallel package
        if not name.startswith("_") and name != "parallel":
            globals()[name] = val


_populate()

View File

@ -14,8 +14,8 @@ from typing import (
import torch import torch
from torch.distributed._shard.sharded_tensor.api import ShardedTensor from torch.distributed._shard.sharded_tensor.api import ShardedTensor
from torch.distributed._tensor import DTensor
from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE
from torch.distributed.tensor import DTensor
PATH_ITEM = Union[str, int] PATH_ITEM = Union[str, int]

View File

@ -11,7 +11,6 @@ from typing import Any, cast, Dict, List, Optional, Tuple, Union
import torch import torch
from torch.distributed._shard._utils import narrow_tensor_by_index from torch.distributed._shard._utils import narrow_tensor_by_index
from torch.distributed._tensor import DTensor
from torch.distributed.checkpoint._dedup_save_plans import dedup_save_plans from torch.distributed.checkpoint._dedup_save_plans import dedup_save_plans
from torch.distributed.checkpoint._nested_dict import ( from torch.distributed.checkpoint._nested_dict import (
FLATTEN_MAPPING, FLATTEN_MAPPING,
@ -45,6 +44,7 @@ from torch.distributed.checkpoint.planner_helpers import (
_init_state_dict, _init_state_dict,
) )
from torch.distributed.checkpoint.utils import find_state_dict_object from torch.distributed.checkpoint.utils import find_state_dict_object
from torch.distributed.tensor import DTensor
logger: logging.Logger = logging.getLogger(__name__) logger: logging.Logger = logging.getLogger(__name__)

View File

@ -11,12 +11,12 @@ import torch.distributed.checkpoint as dcp
import torch.multiprocessing as mp import torch.multiprocessing as mp
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from torch.distributed._tensor.device_mesh import init_device_mesh
from torch.distributed.checkpoint.state_dict import ( from torch.distributed.checkpoint.state_dict import (
_patch_model_state_dict, _patch_model_state_dict,
_patch_optimizer_state_dict, _patch_optimizer_state_dict,
) )
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.tensor.device_mesh import init_device_mesh
DEVICE = "cuda" DEVICE = "cuda"

View File

@ -12,11 +12,11 @@ import torch.distributed as dist
import torch.distributed.checkpoint as dcp import torch.distributed.checkpoint as dcp
import torch.multiprocessing as mp import torch.multiprocessing as mp
import torch.nn as nn import torch.nn as nn
from torch.distributed._tensor.device_mesh import init_device_mesh
from torch.distributed.checkpoint.state_dict import ( from torch.distributed.checkpoint.state_dict import (
_patch_model_state_dict, _patch_model_state_dict,
_patch_optimizer_state_dict, _patch_optimizer_state_dict,
) )
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

View File

@ -12,7 +12,6 @@ from torch.distributed._shard.sharded_tensor.metadata import (
) )
from torch.distributed._shard.sharded_tensor.shard import Shard from torch.distributed._shard.sharded_tensor.shard import Shard
from torch.distributed._shard.sharding_spec.chunk_sharding_spec import ChunkShardingSpec from torch.distributed._shard.sharding_spec.chunk_sharding_spec import ChunkShardingSpec
from torch.distributed._tensor import DTensor
from torch.distributed.checkpoint._nested_dict import unflatten_state_dict from torch.distributed.checkpoint._nested_dict import unflatten_state_dict
from torch.distributed.checkpoint.default_planner import DefaultLoadPlanner from torch.distributed.checkpoint.default_planner import DefaultLoadPlanner
from torch.distributed.checkpoint.metadata import ( from torch.distributed.checkpoint.metadata import (
@ -39,6 +38,7 @@ from torch.distributed.checkpoint.utils import (
from torch.distributed.distributed_c10d import _get_default_group from torch.distributed.distributed_c10d import _get_default_group
from torch.distributed.fsdp._shard_utils import _create_chunk_sharded_tensor from torch.distributed.fsdp._shard_utils import _create_chunk_sharded_tensor
from torch.distributed.remote_device import _remote_device from torch.distributed.remote_device import _remote_device
from torch.distributed.tensor import DTensor
STATE_DICT_2D_LAYOUT = Dict[str, Tuple[Optional[Sequence[int]], Sequence[int]]] STATE_DICT_2D_LAYOUT = Dict[str, Tuple[Optional[Sequence[int]], Sequence[int]]]

View File

@ -7,8 +7,8 @@ import torch.distributed as dist
from torch._utils import _get_device_module from torch._utils import _get_device_module
from torch.distributed._shard.metadata import ShardMetadata from torch.distributed._shard.metadata import ShardMetadata
from torch.distributed._shard.sharded_tensor import ShardedTensor from torch.distributed._shard.sharded_tensor import ShardedTensor
from torch.distributed._tensor import DTensor from torch.distributed.tensor import DTensor
from torch.distributed._tensor._utils import compute_local_shape_and_global_offset from torch.distributed.tensor._utils import compute_local_shape_and_global_offset
from .metadata import ( from .metadata import (
BytesStorageMetadata, BytesStorageMetadata,

View File

@ -32,7 +32,6 @@ from torch.distributed._state_dict_utils import (
_offload_state_dict_to_cpu, _offload_state_dict_to_cpu,
_unflatten_state_dict, _unflatten_state_dict,
) )
from torch.distributed._tensor import DTensor
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
_CHECKPOINT_PREFIX, _CHECKPOINT_PREFIX,
) )
@ -50,6 +49,7 @@ from torch.distributed.fsdp._common_utils import (
_get_module_fsdp_state_if_fully_sharded_module, _get_module_fsdp_state_if_fully_sharded_module,
FSDP_WRAPPED_MODULE, FSDP_WRAPPED_MODULE,
) )
from torch.distributed.tensor import DTensor
from torch.nn.modules.module import _IncompatibleKeys from torch.nn.modules.module import _IncompatibleKeys
from torch.nn.parallel import DistributedDataParallel as DDP from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils._pytree import tree_map_only from torch.utils._pytree import tree_map_only

View File

@ -1887,7 +1887,7 @@ class FlatParamHandle:
flat_param = self.flat_param flat_param = self.flat_param
self._check_unsharded(flat_param) self._check_unsharded(flat_param)
views = self._get_unflat_views() views = self._get_unflat_views()
from torch.distributed._tensor import DTensor from torch.distributed.tensor import DTensor
for i, (view, (param_name, module, _)) in enumerate( for i, (view, (param_name, module, _)) in enumerate(
zip(views, flat_param._param_infos) zip(views, flat_param._param_infos)
@ -2717,7 +2717,7 @@ def _warn_use_fake_reduce(log: logging.Logger, warning: str):
def _same_storage(a, b): def _same_storage(a, b):
# Params are DTensors in backward # Params are DTensors in backward
# with SHARD_GRAD_OP + TP # with SHARD_GRAD_OP + TP
from torch.distributed._tensor import DTensor from torch.distributed.tensor import DTensor
if isinstance(a, DTensor): if isinstance(a, DTensor):
a = a._local_tensor a = a._local_tensor

View File

@ -5,12 +5,12 @@ import torch
import torch.distributed as dist import torch.distributed as dist
from torch.distributed._shard.sharded_tensor.api import ShardedTensor from torch.distributed._shard.sharded_tensor.api import ShardedTensor
from torch.distributed._shard.sharded_tensor.shard import Shard from torch.distributed._shard.sharded_tensor.shard import Shard
from torch.distributed._tensor import DeviceMesh, DTensor
from torch.distributed.fsdp._shard_utils import ( from torch.distributed.fsdp._shard_utils import (
_all_gather_dtensor, _all_gather_dtensor,
_create_chunk_dtensor, _create_chunk_dtensor,
_create_chunk_sharded_tensor, _create_chunk_sharded_tensor,
) )
from torch.distributed.tensor import DeviceMesh, DTensor
class FSDPExtensions(ABC): class FSDPExtensions(ABC):

View File

@ -27,7 +27,6 @@ import torch.distributed as dist
import torch.distributed.fsdp._traversal_utils as traversal_utils import torch.distributed.fsdp._traversal_utils as traversal_utils
import torch.nn as nn import torch.nn as nn
from torch.distributed._state_dict_utils import _gather_state_dict from torch.distributed._state_dict_utils import _gather_state_dict
from torch.distributed._tensor import DTensor, Replicate
from torch.distributed.distributed_c10d import _get_pg_default_device from torch.distributed.distributed_c10d import _get_pg_default_device
from torch.distributed.fsdp._common_utils import ( from torch.distributed.fsdp._common_utils import (
_apply_to_modules, _apply_to_modules,
@ -53,6 +52,7 @@ from torch.distributed.fsdp.api import (
StateDictSettings, StateDictSettings,
StateDictType, StateDictType,
) )
from torch.distributed.tensor import DTensor, Replicate
from torch.utils._pytree import tree_map_only from torch.utils._pytree import tree_map_only

View File

@ -15,7 +15,7 @@ from torch.distributed._shard.sharded_tensor import (
TensorProperties, TensorProperties,
) )
from torch.distributed._shard.sharding_spec import ShardMetadata from torch.distributed._shard.sharding_spec import ShardMetadata
from torch.distributed._tensor import DeviceMesh, DTensor, Replicate, Shard as DShard from torch.distributed.tensor import DeviceMesh, DTensor, Replicate, Shard as DShard
def _get_remote_device_str(rank, device_type, num_devices_per_node): def _get_remote_device_str(rank, device_type, num_devices_per_node):

View File

@ -25,7 +25,6 @@ from torch.distributed._shard.sharded_tensor import (
Shard, Shard,
ShardedTensor, ShardedTensor,
) )
from torch.distributed._tensor import DTensor
from torch.distributed.device_mesh import _mesh_resources from torch.distributed.device_mesh import _mesh_resources
from torch.distributed.fsdp._common_utils import ( from torch.distributed.fsdp._common_utils import (
_FSDPState, _FSDPState,
@ -49,6 +48,7 @@ from torch.distributed.fsdp.api import (
ShardingStrategy, ShardingStrategy,
StateDictType, StateDictType,
) )
from torch.distributed.tensor import DTensor
from torch.distributed.utils import _replace_by_prefix from torch.distributed.utils import _replace_by_prefix
from ._fsdp_extensions import ( from ._fsdp_extensions import (

View File

@ -25,7 +25,6 @@ import torch
import torch.distributed as dist import torch.distributed as dist
import torch.distributed.fsdp._traversal_utils as traversal_utils import torch.distributed.fsdp._traversal_utils as traversal_utils
import torch.nn as nn import torch.nn as nn
from torch.distributed._tensor import DeviceMesh
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
_CHECKPOINT_WRAPPED_MODULE, _CHECKPOINT_WRAPPED_MODULE,
ActivationWrapper, ActivationWrapper,
@ -84,6 +83,7 @@ from torch.distributed.fsdp.api import (
StateDictSettings, StateDictSettings,
StateDictType, StateDictType,
) )
from torch.distributed.tensor import DeviceMesh
from torch.distributed.utils import _p_assert from torch.distributed.utils import _p_assert
from ._flat_param import FlatParameter, FlatParamHandle from ._flat_param import FlatParameter, FlatParamHandle

View File

@ -0,0 +1,57 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
import torch
import torch.distributed.tensor._ops # force import all built-in dtensor ops
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
from torch.distributed.tensor.api import (
distribute_module,
distribute_tensor,
DTensor,
empty,
full,
ones,
rand,
randn,
zeros,
)
from torch.distributed.tensor.placement_types import (
Partial,
Placement,
Replicate,
Shard,
)
from torch.optim.optimizer import (
_foreach_supported_types as _optim_foreach_supported_types,
)
from torch.utils._foreach_utils import (
_foreach_supported_types as _util_foreach_supported_types,
)
# All public APIs from dtensor package
__all__ = [
"DTensor",
"DeviceMesh",
"distribute_tensor",
"distribute_module",
"init_device_mesh,",
"Shard",
"Replicate",
"Partial",
"Placement",
"ones",
"empty",
"full",
"rand",
"randn",
"zeros",
]
# Append DTensor to the list of supported types for foreach implementation for optimizer
# and clip_grad_norm_ so that we will try to use foreach over the for-loop implementation on CUDA.
if DTensor not in _optim_foreach_supported_types:
_optim_foreach_supported_types.append(DTensor)
if DTensor not in _util_foreach_supported_types:
_util_foreach_supported_types.append(DTensor)

View File

@ -7,7 +7,7 @@ from typing import List, Optional
import torch import torch
import torch.distributed._functional_collectives as funcol import torch.distributed._functional_collectives as funcol
import torch.distributed._tensor.placement_types as placement_types import torch.distributed.tensor.placement_types as placement_types
from torch._C._distributed_c10d import _resolve_process_group from torch._C._distributed_c10d import _resolve_process_group
from torch.distributed.device_mesh import _mesh_resources, DeviceMesh from torch.distributed.device_mesh import _mesh_resources, DeviceMesh
from torch.distributed.distributed_c10d import ( from torch.distributed.distributed_c10d import (

View File

@ -8,24 +8,24 @@ from typing import cast, Dict, List, Optional, Sequence, Tuple, TYPE_CHECKING
import torch import torch
import torch.distributed as dist import torch.distributed as dist
import torch.distributed._tensor.api as dtensor import torch.distributed.tensor.api as dtensor
import torch.distributed._tensor.random as random import torch.distributed.tensor.random as random
from torch.distributed._tensor._op_schema import ( from torch.distributed.tensor._op_schema import (
_is_inplace_op, _is_inplace_op,
_is_out_variant_op, _is_out_variant_op,
OpInfo, OpInfo,
OpSchema, OpSchema,
OutputSpecType, OutputSpecType,
) )
from torch.distributed._tensor._redistribute import redistribute_local_tensor from torch.distributed.tensor._redistribute import redistribute_local_tensor
from torch.distributed._tensor._sharding_prop import ShardingPropagator from torch.distributed.tensor._sharding_prop import ShardingPropagator
from torch.distributed._tensor._tp_conv import ( from torch.distributed.tensor._tp_conv import (
convolution_backward_handler, convolution_backward_handler,
convolution_handler, convolution_handler,
) )
from torch.distributed._tensor._utils import try_find_mesh_from_args from torch.distributed.tensor._utils import try_find_mesh_from_args
from torch.distributed._tensor.placement_types import DTensorSpec, Replicate, TensorMeta from torch.distributed.tensor.placement_types import DTensorSpec, Replicate, TensorMeta
from torch.distributed._tensor.random import is_rng_supported_mesh from torch.distributed.tensor.random import is_rng_supported_mesh
if TYPE_CHECKING: if TYPE_CHECKING:

View File

@ -5,8 +5,8 @@ from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
import torch import torch
from torch._ops import OpOverload from torch._ops import OpOverload
from torch.distributed._tensor.placement_types import DTensorSpec, Placement
from torch.distributed.device_mesh import DeviceMesh from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor.placement_types import DTensorSpec, Placement
try: try:

View File

@ -2,15 +2,15 @@
from typing import cast, Dict, List, Optional, Tuple from typing import cast, Dict, List, Optional, Tuple
import torch import torch
from torch.distributed._tensor._op_schema import ( from torch.distributed.tensor._op_schema import (
_is_inplace_op, _is_inplace_op,
_is_out_variant_op, _is_out_variant_op,
OpSchema, OpSchema,
OutputSharding, OutputSharding,
) )
from torch.distributed._tensor._utils import compute_local_shape from torch.distributed.tensor._ops.utils import prod
from torch.distributed._tensor.ops.utils import prod from torch.distributed.tensor._utils import compute_local_shape
from torch.distributed._tensor.placement_types import DTensorSpec, TensorMeta from torch.distributed.tensor.placement_types import DTensorSpec, TensorMeta
def _replace_char_in_str(string: str, new_char: str, idx: int) -> str: def _replace_char_in_str(string: str, new_char: str, idx: int) -> str:

View File

@ -4,9 +4,9 @@
from typing import List from typing import List
import torch import torch
from torch.distributed._tensor._op_schema import OpSchema, OutputSharding from torch.distributed.tensor._op_schema import OpSchema, OutputSharding
from torch.distributed._tensor.ops.utils import register_prop_rule from torch.distributed.tensor._ops.utils import register_prop_rule
from torch.distributed._tensor.placement_types import DTensorSpec, TensorMeta from torch.distributed.tensor.placement_types import DTensorSpec, TensorMeta
aten = torch.ops.aten aten = torch.ops.aten

View File

@ -2,15 +2,15 @@ import itertools
from dataclasses import dataclass from dataclasses import dataclass
from typing import List, Set, Tuple from typing import List, Set, Tuple
from torch.distributed._tensor._op_schema import OpStrategy, PlacementStrategy from torch.distributed.device_mesh import DeviceMesh
from torch.distributed._tensor.placement_types import ( from torch.distributed.tensor._op_schema import OpStrategy, PlacementStrategy
from torch.distributed.tensor.placement_types import (
DTensorSpec, DTensorSpec,
Partial, Partial,
Placement, Placement,
Replicate, Replicate,
Shard, Shard,
) )
from torch.distributed.device_mesh import DeviceMesh
@dataclass @dataclass

View File

@ -7,23 +7,23 @@ from typing import cast, Optional
import torch import torch
import torch.distributed._functional_collectives as funcol import torch.distributed._functional_collectives as funcol
from torch.distributed._tensor._op_schema import ( from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor._op_schema import (
OpSchema, OpSchema,
OpStrategy, OpStrategy,
PlacementList, PlacementList,
StrategyType, StrategyType,
) )
from torch.distributed._tensor.ops.utils import ( from torch.distributed.tensor._ops.utils import (
expand_to_full_mesh_op_strategy, expand_to_full_mesh_op_strategy,
register_op_strategy, register_op_strategy,
) )
from torch.distributed._tensor.placement_types import ( from torch.distributed.tensor.placement_types import (
Partial, Partial,
Placement, Placement,
Replicate, Replicate,
Shard, Shard,
) )
from torch.distributed.device_mesh import DeviceMesh
aten = torch.ops.aten aten = torch.ops.aten

View File

@ -3,15 +3,15 @@
# implement matrix related ops for distributed tensor # implement matrix related ops for distributed tensor
import torch import torch
from torch.distributed._tensor._op_schema import ( from torch.distributed.tensor._op_schema import (
OpSchema, OpSchema,
OpStrategy, OpStrategy,
PlacementStrategy, PlacementStrategy,
StrategyType, StrategyType,
) )
from torch.distributed._tensor.device_mesh import DeviceMesh from torch.distributed.tensor._ops.utils import register_op_strategy
from torch.distributed._tensor.ops.utils import register_op_strategy from torch.distributed.tensor.device_mesh import DeviceMesh
from torch.distributed._tensor.placement_types import DTensorSpec, Replicate from torch.distributed.tensor.placement_types import DTensorSpec, Replicate
aten = torch.ops.aten aten = torch.ops.aten

View File

@ -7,7 +7,8 @@ from enum import Enum
from typing import cast, List, Optional, Sequence, Tuple, Union from typing import cast, List, Optional, Sequence, Tuple, Union
import torch import torch
from torch.distributed._tensor._op_schema import ( from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor._op_schema import (
OpSchema, OpSchema,
OpStrategy, OpStrategy,
PlacementList, PlacementList,
@ -15,8 +16,7 @@ from torch.distributed._tensor._op_schema import (
RuntimeSchemaInfo, RuntimeSchemaInfo,
TupleStrategy, TupleStrategy,
) )
from torch.distributed._tensor._utils import normalize_to_torch_size from torch.distributed.tensor._ops.utils import (
from torch.distributed._tensor.ops.utils import (
as_list, as_list,
expand_to_full_mesh_op_strategy, expand_to_full_mesh_op_strategy,
generate_redistribute_costs, generate_redistribute_costs,
@ -25,14 +25,14 @@ from torch.distributed._tensor.ops.utils import (
normalize_dims, normalize_dims,
register_op_strategy, register_op_strategy,
) )
from torch.distributed._tensor.placement_types import ( from torch.distributed.tensor._utils import normalize_to_torch_size
from torch.distributed.tensor.placement_types import (
DTensorSpec, DTensorSpec,
Partial, Partial,
Placement, Placement,
Replicate, Replicate,
Shard, Shard,
) )
from torch.distributed.device_mesh import DeviceMesh
aten = torch.ops.aten aten = torch.ops.aten

View File

@ -5,14 +5,15 @@
from typing import List from typing import List
import torch import torch
from torch.distributed._tensor._op_schema import ( from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor._op_schema import (
OpSchema, OpSchema,
OpStrategy, OpStrategy,
PlacementList, PlacementList,
PlacementStrategy, PlacementStrategy,
) )
from torch.distributed._tensor.ops._einsum_strategy import gen_einsum_strategies from torch.distributed.tensor._ops._einsum_strategy import gen_einsum_strategies
from torch.distributed._tensor.ops.utils import ( from torch.distributed.tensor._ops.utils import (
expand_to_full_mesh_op_strategy, expand_to_full_mesh_op_strategy,
generate_redistribute_costs, generate_redistribute_costs,
infer_broadcast_dims_map, infer_broadcast_dims_map,
@ -20,13 +21,12 @@ from torch.distributed._tensor.ops.utils import (
map_placements_after_broadcast, map_placements_after_broadcast,
register_op_strategy, register_op_strategy,
) )
from torch.distributed._tensor.placement_types import ( from torch.distributed.tensor.placement_types import (
DTensorSpec, DTensorSpec,
Placement, Placement,
Replicate, Replicate,
Shard, Shard,
) )
from torch.distributed.device_mesh import DeviceMesh
aten = torch.ops.aten aten = torch.ops.aten

View File

@ -2,7 +2,8 @@
from typing import List, Sequence, Tuple from typing import List, Sequence, Tuple
import torch import torch
from torch.distributed._tensor._op_schema import ( from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor._op_schema import (
_is_inplace_op, _is_inplace_op,
_is_out_variant_op, _is_out_variant_op,
OpSchema, OpSchema,
@ -12,21 +13,20 @@ from torch.distributed._tensor._op_schema import (
StrategyType, StrategyType,
TupleStrategy, TupleStrategy,
) )
from torch.distributed._tensor.ops.utils import ( from torch.distributed.tensor._ops.utils import (
generate_redistribute_costs, generate_redistribute_costs,
infer_broadcast_dims_map, infer_broadcast_dims_map,
map_placements_after_broadcast, map_placements_after_broadcast,
normalize_dim, normalize_dim,
register_op_strategy, register_op_strategy,
) )
from torch.distributed._tensor.placement_types import ( from torch.distributed.tensor.placement_types import (
DTensorSpec, DTensorSpec,
Partial, Partial,
Placement, Placement,
Replicate, Replicate,
Shard, Shard,
) )
from torch.distributed.device_mesh import DeviceMesh
aten = torch.ops.aten aten = torch.ops.aten

View File

@ -1,14 +1,14 @@
# mypy: allow-untyped-decorators # mypy: allow-untyped-decorators
# Copyright (c) Meta Platforms, Inc. and affiliates # Copyright (c) Meta Platforms, Inc. and affiliates
import torch import torch
from torch.distributed._tensor._op_schema import ( from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor._op_schema import (
OpSchema, OpSchema,
OpStrategy, OpStrategy,
PlacementStrategy, PlacementStrategy,
StrategyType, StrategyType,
) )
from torch.distributed._tensor.ops.utils import is_tensor_partial, register_op_strategy from torch.distributed.tensor._ops.utils import is_tensor_partial, register_op_strategy
from torch.distributed.device_mesh import DeviceMesh
aten = torch.ops.aten aten = torch.ops.aten

View File

@@ -4,7 +4,8 @@
 from typing import cast, List, Optional, Sequence, Tuple
 import torch
-from torch.distributed._tensor._op_schema import (
+from torch.distributed.device_mesh import DeviceMesh
+from torch.distributed.tensor._op_schema import (
 _is_inplace_op,
 OpSchema,
 OpStrategy,
@@ -15,9 +16,9 @@ from torch.distributed._tensor._op_schema import (
 StrategyType,
 TupleStrategy,
 )
-from torch.distributed._tensor.ops._common_rules import pointwise_rule
-from torch.distributed._tensor.ops._embedding_ops import _MaskPartial
-from torch.distributed._tensor.ops.utils import (
+from torch.distributed.tensor._ops._common_rules import pointwise_rule
+from torch.distributed.tensor._ops._embedding_ops import _MaskPartial
+from torch.distributed.tensor._ops.utils import (
 expand_to_full_mesh_op_strategy,
 is_tensor_dim_sharded,
 is_tensor_evenly_shardable,
@@ -26,14 +27,13 @@ from torch.distributed._tensor.ops.utils import (
 register_op_strategy,
 register_prop_rule,
 )
-from torch.distributed._tensor.placement_types import (
+from torch.distributed.tensor.placement_types import (
 DTensorSpec,
 Partial,
 Placement,
 Replicate,
 Shard,
 )
-from torch.distributed.device_mesh import DeviceMesh
 aten = torch.ops.aten

View File

@@ -17,23 +17,27 @@ from typing import (
 import torch
 from torch import Tensor
-from torch.distributed._tensor._op_schema import (
+from torch.distributed.device_mesh import DeviceMesh
+from torch.distributed.tensor._op_schema import (
 OpSchema,
 OpStrategy,
 PlacementStrategy,
 RuntimeSchemaInfo,
 StrategyType,
 )
-from torch.distributed._tensor.api import Shard
-from torch.distributed._tensor.ops.utils import (
+from torch.distributed.tensor._ops.utils import (
 generate_redistribute_costs,
 normalize_dim,
 normalize_dims,
 prod,
 register_op_strategy,
 )
-from torch.distributed._tensor.placement_types import DTensorSpec, Placement, Replicate
-from torch.distributed.device_mesh import DeviceMesh
+from torch.distributed.tensor.placement_types import (
+DTensorSpec,
+Placement,
+Replicate,
+Shard,
+)
 aten = torch.ops.aten

View File

@@ -6,17 +6,17 @@ import operator
 from typing import cast, Iterable, List, Optional, Sequence, Tuple, Union
 import torch
-from torch.distributed._tensor._collective_utils import redistribute_cost
-from torch.distributed._tensor._op_schema import (
+from torch.distributed.tensor._collective_utils import redistribute_cost
+from torch.distributed.tensor._op_schema import (
 OpSchema,
 OpStrategy,
 PlacementList,
 PlacementStrategy,
 RuntimeSchemaInfo,
 )
-from torch.distributed._tensor.api import DTensor
-from torch.distributed._tensor.device_mesh import DeviceMesh
-from torch.distributed._tensor.placement_types import (
+from torch.distributed.tensor.api import DTensor
+from torch.distributed.tensor.device_mesh import DeviceMesh
+from torch.distributed.tensor.placement_types import (
 DTensorSpec,
 Partial,
 Placement,

View File

@@ -6,9 +6,9 @@ from typing import cast, List, NamedTuple, Tuple
 import torch
 import torch.distributed._functional_collectives as funcol
-import torch.distributed._tensor.api as dtensor
-from torch.distributed._tensor.device_mesh import DeviceMesh
-from torch.distributed._tensor.placement_types import (
+import torch.distributed.tensor.api as dtensor
+from torch.distributed.tensor.device_mesh import DeviceMesh
+from torch.distributed.tensor.placement_types import (
 DTensorSpec,
 Partial,
 Placement,

View File

@@ -6,7 +6,8 @@ from typing import Callable, cast, Dict, List, Optional, Sequence, Tuple, Union
 import torch
 from torch._ops import OpOverload
 from torch._subclasses import FakeTensorMode
-from torch.distributed._tensor._op_schema import (
+from torch.distributed.device_mesh import DeviceMesh
+from torch.distributed.tensor._op_schema import (
 OpInfo,
 OpSchema,
 OpStrategy,
@@ -17,13 +18,12 @@ from torch.distributed._tensor._op_schema import (
 StrategyType,
 TupleStrategy,
 )
-from torch.distributed._tensor._utils import (
+from torch.distributed.tensor._utils import (
 compute_local_shape,
 compute_local_stride,
 try_find_mesh_from_args,
 )
-from torch.distributed._tensor.placement_types import DTensorSpec, TensorMeta
-from torch.distributed.device_mesh import DeviceMesh
+from torch.distributed.tensor.placement_types import DTensorSpec, TensorMeta
 aten = torch.ops.aten

View File

@@ -5,7 +5,7 @@ from typing import cast, Dict, List, Tuple
 import torch
 import torch.distributed as dist
-import torch.distributed._tensor.api as dtensor
+import torch.distributed.tensor.api as dtensor
 aten = torch.ops.aten

View File

@@ -1,9 +1,10 @@
 from typing import cast, List, Sequence, Tuple
 import torch
-import torch.distributed._tensor.api as dtensor
+import torch.distributed.tensor.api as dtensor
 from torch._prims_common import ShapeType
-from torch.distributed._tensor.placement_types import (
+from torch.distributed.device_mesh import DeviceMesh
+from torch.distributed.tensor.placement_types import (
 _StridedShard,
 DTensorSpec,
 Partial,
@@ -11,7 +12,6 @@ from torch.distributed._tensor.placement_types import (
 Replicate,
 Shard,
 )
-from torch.distributed.device_mesh import DeviceMesh
 # TODO: audit existing code base to see if we can safely remove this API.

View File

@@ -6,23 +6,21 @@ import warnings
 from typing import Any, Callable, cast, Optional, Sequence, Tuple
 import torch
-import torch.distributed._tensor._dispatch as op_dispatch
-import torch.distributed._tensor.random as random
+import torch.distributed.tensor._dispatch as op_dispatch
+import torch.distributed.tensor.random as random
 import torch.nn as nn
-from torch.distributed._tensor._collective_utils import (
-check_tensor_meta,
-mesh_broadcast,
-)
-from torch.distributed._tensor._redistribute import (
+from torch.distributed.device_mesh import _mesh_resources, DeviceMesh
+from torch.distributed.tensor._collective_utils import check_tensor_meta, mesh_broadcast
+from torch.distributed.tensor._redistribute import (
 Redistribute,
 redistribute_local_tensor,
 )
-from torch.distributed._tensor._utils import (
+from torch.distributed.tensor._utils import (
 compute_global_tensor_info,
 compute_local_shape,
 normalize_to_torch_size,
 )
-from torch.distributed._tensor.placement_types import (
+from torch.distributed.tensor.placement_types import (
 DTensorSpec,
 Partial,
 Placement,
@@ -30,11 +28,7 @@ from torch.distributed._tensor.placement_types import (
 Shard,
 TensorMeta,
 )
-from torch.distributed._tensor.random import (
-is_rng_supported_mesh,
-OffsetBasedRNGTracker,
-)
-from torch.distributed.device_mesh import _mesh_resources, DeviceMesh
+from torch.distributed.tensor.random import is_rng_supported_mesh, OffsetBasedRNGTracker
 __all__ = [
@@ -254,7 +248,8 @@ class DTensor(torch.Tensor):
 """
 Construct a DTensor from a local tensor, device mesh, and placement and
 other tensor properties (i.e. shape, requires_grad, strides, etc).
-Note: This is not a public API and it's only supposed to be used by the
+.. note:: This is not a public API and it's only supposed to be used by the
 operator implementations and internals. If you want to construct a
 DTensor from a local tensor, consider using ``DTensor.from_local``, if
 you want to construct a DTensor from a "global" tensor (where you
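The docstring above points users at the two public construction paths; a hedged sketch of both follows, not part of the diff (sizes and the gloo/CPU mesh are illustrative; assumes a torchrun launch):

# sketch of the two supported construction paths mentioned in the docstring
import torch
import torch.distributed as dist
from torch.distributed.tensor import DTensor, Shard, distribute_tensor, init_device_mesh

dist.init_process_group("gloo")
mesh = init_device_mesh("cpu", (dist.get_world_size(),))

# 1) you already hold the per-rank shard: wrap it with from_local
local = torch.randn(4, 8)  # each rank's local shard
dt1 = DTensor.from_local(local, mesh, [Shard(0)])

# 2) you hold the full ("global") tensor on every rank: let distribute_tensor shard it
full = torch.arange(64, dtype=torch.float32).reshape(8, 8)
dt2 = distribute_tensor(full, mesh, [Shard(1)])

print(dt1.placements, dt2.full_tensor().shape)  # full_tensor() gathers the global view
dist.destroy_process_group()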

View File

@@ -1,6 +1,6 @@
 # mypy: allow-untyped-defs
-from torch.distributed._tensor.debug.comm_mode import CommDebugMode
-from torch.distributed._tensor.debug.visualize_sharding import visualize_sharding
+from torch.distributed.tensor.debug.comm_mode import CommDebugMode
+from torch.distributed.tensor.debug.visualize_sharding import visualize_sharding
 __all__ = ["CommDebugMode", "visualize_sharding"]
@@ -12,7 +12,7 @@ def _get_sharding_prop_cache_info():
 This would return a named tuple showing hits, misses, maxsize and cursize of the sharding
 propagator cache.
 """
-from torch.distributed._tensor.api import DTensor
+from torch.distributed.tensor.api import DTensor
 return (
 DTensor._op_dispatcher.sharding_propagator.propagate_op_sharding.cache_info() # type:ignore[attr-defined]
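`CommDebugMode` is re-exported here from the public `torch.distributed.tensor.debug` package; a hedged usage sketch, not part of the diff (the matmul and shapes are illustrative, and `get_total_counts`/`get_comm_counts` are assumed to be the accessors available on current builds):

# sketch: counting collectives triggered by a DTensor matmul
import torch
import torch.distributed as dist
from torch.distributed.tensor import Shard, distribute_tensor, init_device_mesh
from torch.distributed.tensor.debug import CommDebugMode

dist.init_process_group("gloo")
mesh = init_device_mesh("cpu", (dist.get_world_size(),))

a = distribute_tensor(torch.randn(8, 8), mesh, [Shard(0)])
b = distribute_tensor(torch.randn(8, 8), mesh, [Shard(1)])

with CommDebugMode() as comm_mode:
    (a @ b).full_tensor()  # redistribution here issues collectives

print(comm_mode.get_total_counts())  # total number of collective calls observed
print(comm_mode.get_comm_counts())   # per-collective breakdown
dist.destroy_process_group()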

View File

@@ -8,7 +8,7 @@ import torch.nn as nn
 from functorch.compile import make_boxed_func
 from torch._functorch.compilers import aot_module
 from torch._inductor.decomposition import select_decomp_table
-from torch.distributed._tensor import DTensor
+from torch.distributed.tensor import DTensor
 inductor_decomps = select_decomp_table()

View File

@@ -7,11 +7,11 @@ from collections import defaultdict
 from typing import Any, Dict
 import torch
+import torch.distributed._tools.mod_tracker as mod_tracker
 import torch.nn
 from torch._guards import detect_fake_mode
 from torch.autograd.graph import register_multi_grad_hook
-from torch.distributed._tensor.api import DTensor
-from torch.distributed._tools.mod_tracker import ModTracker
+from torch.distributed.tensor.api import DTensor
 from torch.nn.modules.module import (
 register_module_forward_hook,
 register_module_forward_pre_hook,
@@ -69,7 +69,7 @@ trivial_ops = {
 }
-class CommModeModuleTracker(ModTracker):
+class _CommModeModuleTracker(mod_tracker.ModTracker):
 """
 Inherits ModuleTracker and expands on its functionality to track the
 parameters and sharding information of a model at a module-level
@@ -250,7 +250,7 @@ class CommDebugMode(TorchDispatchMode):
 self.comm_registry.add(py_op)
 self.comm_registry.add(torch.ops._dtensor.shard_dim_alltoall)
-self.advanced_module_tracker = CommModeModuleTracker()
+self.advanced_module_tracker = _CommModeModuleTracker()
 def generate_json_dump(self, file_name="comm_mode_log.json", noise_level=3):
 """

View File

@@ -4,8 +4,8 @@ from typing import List, Sequence, Tuple
 import numpy as np
 from torch._prims_common import ShapeType
-from torch.distributed._tensor import DeviceMesh
-from torch.distributed._tensor.placement_types import Placement, Shard
+from torch.distributed.tensor import DeviceMesh
+from torch.distributed.tensor.placement_types import Placement, Shard
 __all__ = ["visualize_sharding"]

View File

@@ -8,8 +8,8 @@ from typing import Callable, Dict, Union
 import torch
 import torch.nn as nn
-from torch.distributed._tensor import DeviceMesh
-from torch.distributed._tensor.debug import CommDebugMode
+from torch.distributed.tensor import DeviceMesh
+from torch.distributed.tensor.debug import CommDebugMode
 from torch.distributed.tensor.parallel import (
 ColwiseParallel,
 parallelize_module,
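For readers new to the tensor-parallel API this example drives, a hedged sketch of `parallelize_module` with a column-wise plan, not part of the diff (the model, layer sizes, and the "0" module name in the plan are illustrative):

# sketch: column-wise tensor parallelism on a single Linear layer
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import ColwiseParallel, parallelize_module

dist.init_process_group("gloo")
mesh = init_device_mesh("cpu", (dist.get_world_size(),))

model = nn.Sequential(nn.Linear(16, 32), nn.ReLU())
# shard the first Linear's weight column-wise across the mesh
model = parallelize_module(model, mesh, {"0": ColwiseParallel()})

out = model(torch.randn(4, 16))
print(out.shape)  # local shard of the column-parallel output
dist.destroy_process_group()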

View File

@@ -12,7 +12,7 @@ import time
 import torch
 import torch.distributed as dist
 import torch.nn as nn
-from torch.distributed._tensor import (
+from torch.distributed.tensor import (
 DeviceMesh,
 distribute_module,
 distribute_tensor,
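For completeness, a hedged sketch of `distribute_module`, which the example above now imports from the public path, not part of the diff (the partition function name and layer sizes are illustrative):

# sketch: sharding a module's parameters with distribute_module
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.tensor import (
    DTensor,
    Shard,
    distribute_module,
    distribute_tensor,
    init_device_mesh,
)

dist.init_process_group("gloo")
mesh = init_device_mesh("cpu", (dist.get_world_size(),))

def shard_params(name, module, device_mesh):
    # illustrative partition_fn: shard every parameter on dim 0
    for pname, param in module.named_parameters(recurse=False):
        dist_param = nn.Parameter(distribute_tensor(param, device_mesh, [Shard(0)]))
        module.register_parameter(pname, dist_param)

sharded = distribute_module(nn.Linear(8, 8), mesh, partition_fn=shard_params)
print(isinstance(sharded.weight, DTensor))  # True: parameters are now DTensors
dist.destroy_process_group()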

View File

@@ -9,23 +9,23 @@ from functools import cached_property
 from typing import List, TYPE_CHECKING
 import torch
-from torch.distributed._tensor import (
+from torch.distributed.checkpoint.metadata import (
+ChunkStorageMetadata,
+TensorProperties,
+TensorStorageMetadata,
+)
+from torch.distributed.tensor import (
 DeviceMesh,
 DTensor,
 init_device_mesh,
 Replicate,
 Shard,
 )
-from torch.distributed._tensor.debug.visualize_sharding import visualize_sharding
-from torch.distributed.checkpoint.metadata import (
-ChunkStorageMetadata,
-TensorProperties,
-TensorStorageMetadata,
-)
+from torch.distributed.tensor.debug.visualize_sharding import visualize_sharding
 if TYPE_CHECKING:
-from torch.distributed._tensor.placement_types import Placement
+from torch.distributed.tensor.placement_types import Placement
 def get_device_type():

View File

@@ -6,8 +6,8 @@ torchrun --standalone --nnodes=1 --nproc-per-node=4 visualize_sharding_example.p
 import os
 import torch
-from torch.distributed._tensor import DeviceMesh, distribute_tensor, Replicate, Shard
-from torch.distributed._tensor.debug.visualize_sharding import visualize_sharding
+from torch.distributed.tensor import DeviceMesh, distribute_tensor, Replicate, Shard
+from torch.distributed.tensor.debug.visualize_sharding import visualize_sharding
 world_size = int(os.environ["WORLD_SIZE"])
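A hedged sketch of what that example does after the rename, not part of the diff (the tensor shape is illustrative; assumes the script is launched with torchrun as its header suggests):

# sketch: printing an ASCII view of how a DTensor is laid out across ranks
import os
import torch
import torch.distributed as dist
from torch.distributed.tensor import DeviceMesh, Shard, distribute_tensor
from torch.distributed.tensor.debug.visualize_sharding import visualize_sharding

dist.init_process_group("gloo")
world_size = int(os.environ["WORLD_SIZE"])
mesh = DeviceMesh("cpu", list(range(world_size)))

dt = distribute_tensor(torch.randn(8, 8), mesh, [Shard(0)])
visualize_sharding(dt)  # prints which rows of the tensor each rank owns
dist.destroy_process_group()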

View File

@@ -2,9 +2,9 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates
 from contextlib import contextmanager
-from torch.distributed._tensor.api import DTensor
-from torch.distributed._tensor.experimental.func_map import local_map
-from torch.distributed._tensor.experimental.register_sharding import register_sharding
+from torch.distributed.tensor.api import DTensor
+from torch.distributed.tensor.experimental.func_map import local_map
+from torch.distributed.tensor.experimental.register_sharding import register_sharding
 __all__ = ["implicit_replication", "local_map", "register_sharding"]

View File

@@ -24,8 +24,8 @@ import torch.distributed as dist
 import torch.distributed._functional_collectives as ft_c
 import torch.nn.functional as F
 from torch import nn
-from torch.distributed._tensor import distribute_module, DTensor, Replicate, Shard
 from torch.distributed.device_mesh import DeviceMesh
+from torch.distributed.tensor import distribute_module, DTensor, Replicate, Shard
 from torch.distributed.tensor.parallel.style import ParallelStyle

View File

@@ -4,8 +4,8 @@ from typing import Callable, Optional, Sequence, Tuple, Union
 import torch
 from torch.distributed._functional_collectives import AsyncCollectiveTensor
-from torch.distributed._tensor import DeviceMesh, DTensor
-from torch.distributed._tensor.placement_types import Placement
+from torch.distributed.tensor import DeviceMesh, DTensor
+from torch.distributed.tensor.placement_types import Placement
 try:
@@ -16,7 +16,6 @@ except ImportError:
 __all__ = ["local_map"]
 PlacementType = Optional[Sequence[Placement]]
 InputPlacements = Optional[Tuple[PlacementType, ...]]
 OutputPlacements = Union[PlacementType, Tuple[PlacementType, ...]]
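Since `local_map` moves with this file, a hedged sketch of its intended use, not part of the diff (the wrapped function and placements are illustrative, and the keyword names follow the type aliases shown above):

# sketch: applying a plain torch function to the local shards of DTensors
import torch
import torch.distributed as dist
from torch.distributed.tensor import Shard, distribute_tensor, init_device_mesh
from torch.distributed.tensor.experimental import local_map

dist.init_process_group("gloo")
mesh = init_device_mesh("cpu", (dist.get_world_size(),))

def local_add(x, y):
    return x + y  # operates on plain torch.Tensor shards

dist_add = local_map(
    local_add,
    out_placements=[Shard(0)],               # how to re-wrap the local result
    in_placements=([Shard(0)], [Shard(0)]),  # expected layout of each input
    device_mesh=mesh,
)

a = distribute_tensor(torch.ones(8, 4), mesh, [Shard(0)])
b = distribute_tensor(torch.ones(8, 4), mesh, [Shard(0)])
print(dist_add(a, b).full_tensor()[0, 0])  # 2.0
dist.destroy_process_group()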

View File

@@ -5,8 +5,8 @@ from typing import Callable, List, Sequence, Tuple, Union
 import torch
 from torch._ops import OpOverload
-from torch.distributed._tensor import DeviceMesh, DTensor
-from torch.distributed._tensor._op_schema import (
+from torch.distributed.tensor import DeviceMesh, DTensor
+from torch.distributed.tensor._op_schema import (
 _is_inplace_op,
 OpSchema,
 OpStrategy,
@@ -15,7 +15,7 @@ from torch.distributed._tensor._op_schema import (
 StrategyType,
 TupleStrategy,
 )
-from torch.distributed._tensor.ops.utils import expand_to_full_mesh_op_strategy
+from torch.distributed.tensor._ops.utils import expand_to_full_mesh_op_strategy
 __all__ = ["register_sharding"]

View File

@@ -5,22 +5,22 @@ from typing import Any, cast, Dict, List, Optional, Sequence, Tuple
 import torch
 from torch._subclasses.fake_tensor import FakeTensor
-from torch.distributed._tensor import DeviceMesh, distribute_tensor, DTensor
-from torch.distributed._tensor._op_schema import (
+from torch.distributed.tensor import DeviceMesh, distribute_tensor, DTensor
+from torch.distributed.tensor._op_schema import (
 DTensorSpec,
 OpSchema,
 OutputSharding,
 OutputSpecType,
 PlacementStrategy,
 )
-from torch.distributed._tensor._redistribute import redistribute_local_tensor
-from torch.distributed._tensor.placement_types import (
+from torch.distributed.tensor._redistribute import redistribute_local_tensor
+from torch.distributed.tensor.parallel.style import ColwiseParallel, ParallelStyle
+from torch.distributed.tensor.placement_types import (
 Placement,
 Replicate,
 Shard,
 TensorMeta,
 )
-from torch.distributed.tensor.parallel.style import ColwiseParallel, ParallelStyle
 from torch.export import ExportedProgram
 from torch.export.exported_program import ExportGraphSignature
 from torch.fx import GraphModule

View File

@@ -3,8 +3,8 @@ from typing import no_type_check, Optional, Tuple
 import torch
 from torch.distributed._functional_collectives import AsyncCollectiveTensor
-from torch.distributed._tensor import DTensor
-from torch.distributed._tensor.placement_types import DTensorSpec
+from torch.distributed.tensor import DTensor
+from torch.distributed.tensor.placement_types import DTensorSpec
 @no_type_check

View File

@@ -2,9 +2,9 @@
 import warnings
 from typing import Tuple, Union
-from torch.distributed._tensor import DeviceMesh
-from torch.distributed._tensor.placement_types import Placement
 from torch.distributed.device_mesh import _mesh_resources
+from torch.distributed.tensor import DeviceMesh
+from torch.distributed.tensor.placement_types import Placement
 try:

View File

@@ -3,15 +3,15 @@ from fnmatch import fnmatch
 from typing import Dict, Union
 import torch
-import torch.distributed._tensor.random as random
+import torch.distributed.tensor.random as random
 import torch.nn as nn
-from torch.distributed._tensor import DeviceMesh
-from torch.distributed._tensor.random import (
+from torch.distributed.tensor import DeviceMesh
+from torch.distributed.tensor.parallel._utils import _validate_tp_mesh_dim
+from torch.distributed.tensor.parallel.style import ParallelStyle
+from torch.distributed.tensor.random import (
 is_rng_supported_mesh,
 TensorParallelRNGTracker,
 )
-from torch.distributed.tensor.parallel._utils import _validate_tp_mesh_dim
-from torch.distributed.tensor.parallel.style import ParallelStyle
 __all__ = [

View File

@@ -14,12 +14,12 @@ from torch.distributed._shard.sharded_tensor import (
 )
 from torch.distributed._shard.sharding_spec import ShardMetadata
 from torch.distributed._shard.sharding_spec.chunk_sharding_spec import ChunkShardingSpec
-from torch.distributed._tensor import DeviceMesh, DTensor, Replicate, Shard as DShard
 from torch.distributed.device_mesh import _mesh_resources
 from torch.distributed.fsdp._common_utils import _set_fsdp_flattened
 from torch.distributed.fsdp._fsdp_extensions import FSDPExtensions
 from torch.distributed.fsdp._shard_utils import _create_chunk_sharded_tensor
 from torch.distributed.remote_device import _remote_device
+from torch.distributed.tensor import DeviceMesh, DTensor, Replicate, Shard as DShard
 from torch.distributed.tensor.parallel._data_parallel_utils import (
 _flatten_tensor,
 _unflatten_tensor,

View File

@@ -3,7 +3,7 @@ from functools import partial
 from typing import Any, Optional, Tuple
 import torch
-from torch.distributed._tensor import DeviceMesh, DTensor, Replicate, Shard
+from torch.distributed.tensor import DeviceMesh, DTensor, Replicate, Shard
 __all__ = [

View File

@@ -8,15 +8,15 @@ import torch._prims_common as utils
 import torch.distributed._functional_collectives as funcol
 import torch.distributed.distributed_c10d as c10d
 from torch import Tensor
-from torch.distributed._tensor import DTensor, Replicate, Shard
-from torch.distributed._tensor.ops._embedding_ops import _MaskPartial
-from torch.distributed._tensor.ops._math_ops import (
+from torch.distributed.device_mesh import DeviceMesh
+from torch.distributed.tensor import DTensor, Replicate, Shard
+from torch.distributed.tensor._ops._embedding_ops import _MaskPartial
+from torch.distributed.tensor._ops._math_ops import (
 _skip_dim,
 Reduction,
 replicate_reduction_dims,
 )
-from torch.distributed._tensor.placement_types import DTensorSpec, Placement, TensorMeta
-from torch.distributed.device_mesh import DeviceMesh
+from torch.distributed.tensor.placement_types import DTensorSpec, Placement, TensorMeta
 aten = torch.ops.aten

View File

@@ -6,7 +6,7 @@ from typing import Any, Dict, Optional, Tuple, Union
 import torch
 import torch.nn as nn
-from torch.distributed._tensor import (
+from torch.distributed.tensor import (
 DeviceMesh,
 distribute_module,
 distribute_tensor,
@@ -14,7 +14,7 @@ from torch.distributed._tensor import (
 Replicate,
 Shard,
 )
-from torch.distributed._tensor.placement_types import Placement
+from torch.distributed.tensor.placement_types import Placement
 __all__ = [

View File

@@ -6,7 +6,8 @@ from typing import Any, cast, List, NamedTuple, Optional, Tuple
 import torch
 import torch.distributed._functional_collectives as funcol
-from torch.distributed._tensor._collective_utils import (
+from torch.distributed.device_mesh import DeviceMesh
+from torch.distributed.tensor._collective_utils import (
 fill_empty_tensor_to_shards,
 mesh_broadcast,
 mesh_scatter,
@@ -14,7 +15,6 @@ from torch.distributed._tensor._collective_utils import (
 shard_dim_alltoall,
 unpad_tensor,
 )
-from torch.distributed.device_mesh import DeviceMesh
 __all__ = ["Placement", "Shard", "Replicate", "Partial", "DTensorSpec", "TensorMeta"]

View File

@@ -7,8 +7,8 @@ from typing import Dict, List, Optional
 import torch
 import torch.distributed as dist
 from torch import Tensor
-from torch.distributed._tensor.placement_types import DTensorSpec, Shard
 from torch.distributed.device_mesh import _get_device_handle, DeviceMesh
+from torch.distributed.tensor.placement_types import DTensorSpec, Shard
 __all__ = [
@@ -290,7 +290,7 @@ class OffsetBasedRNGTracker(_RNGStateTracker):
 return_offset=False,
 )[0]
-from torch.distributed._tensor.ops.utils import prod
+from torch.distributed.tensor._ops.utils import prod
 local_size = prod(local_size_on_rank_0)
@@ -317,7 +317,7 @@ class OffsetBasedRNGTracker(_RNGStateTracker):
 """
 dtensor_shape = spec.shape
-from torch.distributed._tensor.ops.utils import prod
+from torch.distributed.tensor._ops.utils import prod
 numel = prod(dtensor_shape)
 # pytorch: offset must be multiple of 4

View File

@@ -34,7 +34,6 @@ from torch.distributed._composable.fsdp._fsdp_param_group import (
 FSDPParamGroup,
 RegisterPostBackwardFunction,
 )
-from torch.distributed._tensor import distribute_tensor, DTensor, Shard
 from torch.distributed.device_mesh import DeviceMesh
 from torch.distributed.fsdp import CPUOffload, FullyShardedDataParallel as FSDP
 from torch.distributed.fsdp._common_utils import TrainingState
@@ -46,6 +45,7 @@ from torch.distributed.fsdp.fully_sharded_data_parallel import (
 )
 from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
 from torch.distributed.fsdp.wrap import always_wrap_policy, ModuleWrapPolicy, wrap
+from torch.distributed.tensor import distribute_tensor, DTensor, Shard
 from torch.distributed.tensor.parallel import (
 ColwiseParallel,
 parallelize_module,