Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-06 12:20:52 +01:00)
[dtensor] move DTensor to public namespace (#133113)
Moving DTensor to the public namespace, to formally add a documentation page that covers all the public APIs. This includes:

* many path renames and import fixes
* a dedicated doc page, without much content yet (to be added in the next PRs)
* to preserve BC for users still importing from `torch.distributed._tensor`, a shim script that redirects old import paths to the new module

The BC preservation is evidenced by the fact that all DTensor tests still pass without changing their public imports, so it is safe to land the changes.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133113
Approved by: https://github.com/XilunWu
ghstack dependencies: #133305, #133306
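As an editorial illustration (not part of the commit), a minimal sketch of what the BC shim guarantees, assuming a build that contains this change: the old private path and the new public path resolve to the same objects.

```python
# Minimal sketch: the legacy private module is populated from the new public package,
# so both import paths hand back the same class objects.
import torch.distributed._tensor as legacy_dt   # old path, kept working by the shim
import torch.distributed.tensor as public_dt    # new public path added in this PR

assert legacy_dt.DTensor is public_dt.DTensor
assert legacy_dt.Shard is public_dt.Shard
```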
This commit is contained in:
parent 1a4709cef5
commit 2ee6b97464
@@ -626,6 +626,17 @@ coverage_ignore_functions = [
     # torch.distributed.rpc.internal
     "deserialize",
     "serialize",
+    # torch.distributed.tensor.api
+    "distribute_module",
+    "distribute_tensor",
+    # torch.distributed.tensor.random
+    "is_rng_supported_mesh",
+    # torch.distributed.tensor.experimental
+    "context_parallel",
+    "local_map",
+    "register_sharding",
+    # torch.distributed.tensor.debug
+    "visualize_sharding",
     # torch.distributed.tensor.parallel.api
     "parallelize_module",
     # torch.distributed.tensor.parallel.input_reshard
@@ -2621,6 +2632,15 @@ coverage_ignore_classes = [
     "RemoteException",
     # torch.distributed.rpc.rref_proxy
     "RRefProxy",
+    # torch.distributed.tensor.api
+    "DTensor",
+    # torch.distributed.tensor.placement_types
+    "DTensorSpec",
+    "Placement",
+    # torch.distributed.tensor.random
+    "OffsetBasedRNGTracker",
+    # torch.distributed.tensor.debug
+    "CommDebugMode",
     # torch.distributed.tensor.parallel.fsdp
     "DTensorExtensions",
     # torch.distributed.tensor.parallel.style
@@ -876,7 +876,6 @@ If you are running single node training, it may be convenient to interactively b
 .. py:module:: torch.distributed.nn.api
 .. py:module:: torch.distributed.nn.jit
 .. py:module:: torch.distributed.nn.jit.templates
-.. py:module:: torch.distributed.tensor
 .. py:module:: torch.distributed.algorithms.ddp_comm_hooks.ddp_zero_hook
 .. py:module:: torch.distributed.algorithms.ddp_comm_hooks.debugging_hooks
 .. py:module:: torch.distributed.algorithms.ddp_comm_hooks.default_hooks
docs/source/distributed.tensor.rst (new file, 104 lines)

@@ -0,0 +1,104 @@
+.. role:: hidden
+    :class: hidden-section
+
+PyTorch DTensor (Distributed Tensor)
+======================================================
+
+.. note::
+    ``torch.distributed.tensor`` is currently in an alpha state and under
+    development. We are committed to backward compatibility for most of the APIs
+    listed in this doc, but API changes may happen if necessary.
+
+PyTorch DTensor offers simple and flexible tensor sharding primitives that transparently handle distributed
+logic, including sharded storage, operator computation and collective communications across devices/hosts.
+``DTensor`` can be used to build different parallelism solutions and supports a sharded state_dict representation
+when working with multi-dimensional sharding.
+
+Please see examples from the PyTorch native parallelism solutions that are built on top of ``DTensor``:
+
+* `Tensor Parallel <https://pytorch.org/docs/main/distributed.tensor.parallel.html>`__
+* `FSDP2 <https://github.com/pytorch/torchtitan/blob/main/docs/fsdp.md>`__
+
+.. automodule:: torch.distributed.tensor
+
+.. currentmodule:: torch.distributed.tensor
+
+:class:`DTensor` follows the SPMD (single program, multiple data) programming model to empower users to
+write distributed programs as if they were single-device programs with the same convergence property. It
+provides a uniform tensor sharding layout (DTensor Layout) through specifying the :class:`DeviceMesh`
+and :class:`Placement`:
+
+- :class:`DeviceMesh` represents the device topology and the communicators of the cluster using
+  an n-dimensional array.
+
+- :class:`Placement` describes the sharding layout of the logical tensor on the :class:`DeviceMesh`.
+  DTensor supports three types of placements: :class:`Shard`, :class:`Replicate` and :class:`Partial`.
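As an editorial illustration (not part of the diff), a minimal sketch of describing a layout with a ``DeviceMesh`` plus placements, assuming 4 GPUs arranged as a 2x2 mesh:

```python
# Minimal sketch: one placement per mesh dimension.
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Replicate, Shard

mesh_2d = init_device_mesh("cuda", (2, 2), mesh_dim_names=("dp", "tp"))
# Replicate across the "dp" dimension, shard tensor dim 0 across the "tp" dimension.
placements = [Replicate(), Shard(0)]
```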
+There are three ways to construct a :class:`DTensor`:
+
+* :meth:`distribute_tensor` creates a :class:`DTensor` from a logical or "global" ``torch.Tensor`` on
+  each rank. This can be used to shard leaf ``torch.Tensor`` s (i.e. model parameters/buffers
+  and inputs).
+* :meth:`DTensor.from_local` creates a :class:`DTensor` from a local ``torch.Tensor`` on each rank, which can
+  be used to create a :class:`DTensor` from non-leaf ``torch.Tensor`` s (i.e. intermediate activation
+  tensors during forward/backward).
+* DTensor provides dedicated tensor factory methods (e.g. :meth:`empty`, :meth:`ones`, :meth:`randn`, etc.)
+  to allow different :class:`DTensor` creations by directly specifying the :class:`DeviceMesh` and
+  :class:`Placement`.
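An editorial sketch of the three construction paths (not part of the diff), assuming the code runs under ``torchrun`` with 4 ranks and one GPU per rank:

```python
# Minimal sketch of the three ways to build a DTensor.
import torch
import torch.distributed.tensor as dt
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Shard, distribute_tensor

mesh = init_device_mesh("cuda", (4,))

# 1) shard a "global" tensor (e.g. a parameter) along dim 0
global_weight = torch.randn(16, 32)
sharded_weight = distribute_tensor(global_weight, mesh, [Shard(0)])

# 2) wrap an already-local shard (e.g. an intermediate activation)
local_chunk = torch.randn(4, 32, device="cuda")
activation = DTensor.from_local(local_chunk, mesh, [Shard(0)])

# 3) use a DTensor factory function directly
zeros_dt = dt.zeros(16, 32, device_mesh=mesh, placements=[Shard(0)])
```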
+.. autoclass:: DTensor
+    :members:
+    :member-order: bysource
+
+.. autofunction:: distribute_tensor
+
+
+Along with :meth:`distribute_tensor`, DTensor also offers a :meth:`distribute_module` API to allow easier
+sharding at the :class:`nn.Module` level.
+
+.. autofunction:: distribute_module
+
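An editorial sketch of :meth:`distribute_module` (not part of the diff), assuming a 1D mesh over 4 ranks and a toy module; the hypothetical ``shard_params`` partition function shards every parameter along dim 0:

```python
# Minimal sketch: convert a module's parameters to DTensors via a partition function.
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Shard, distribute_module, distribute_tensor

mesh = init_device_mesh("cuda", (4,))
module = nn.Linear(32, 16)

def shard_params(name, module, device_mesh):
    # Replace each parameter with a dim-0 sharded DTensor parameter.
    for p_name, param in module.named_parameters(recurse=False):
        dist_param = nn.Parameter(distribute_tensor(param, device_mesh, [Shard(0)]))
        module.register_parameter(p_name, dist_param)

sharded_module = distribute_module(module, mesh, partition_fn=shard_params)
```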
+DTensor supports the following types of :class:`Placement` on each :class:`DeviceMesh` dimension:
+
+.. autoclass:: Shard
+    :members:
+    :undoc-members:
+
+.. autoclass:: Replicate
+    :members:
+    :undoc-members:
+
+.. autoclass:: Partial
+    :members:
+    :undoc-members:
+
+DTensor provides dedicated tensor factory functions to allow creating :class:`DTensor` directly
+using ``torch.Tensor``-like factory function APIs (e.g. ``torch.ones``, ``torch.empty``, etc.), by additionally
+specifying the :class:`DeviceMesh` and :class:`Placement` for the :class:`DTensor` created:
+
+.. autofunction:: zeros
+
+.. autofunction:: ones
+
+.. autofunction:: empty
+
+.. autofunction:: full
+
+.. autofunction:: rand
+
+.. autofunction:: randn
+
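An editorial sketch of the factory-function pattern (not part of the diff), assuming a 1D mesh over 4 ranks; each call takes the global shape plus ``device_mesh`` and ``placements`` keyword arguments:

```python
# Minimal sketch: the global shape goes in, each rank only materializes its local shard.
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Shard, full, ones

mesh = init_device_mesh("cuda", (4,))

w = ones(8, 4, device_mesh=mesh, placements=[Shard(0)])
b = full((8,), 0.1, device_mesh=mesh, placements=[Shard(0)])

print(w.to_local().shape)  # torch.Size([2, 4]) on each of the 4 ranks
```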
+
+.. modules that are missing docs, add the doc later when necessary
+.. py:module:: torch.distributed.tensor.api
+.. py:module:: torch.distributed.tensor.device_mesh
+.. py:module:: torch.distributed.tensor.random
+.. py:module:: torch.distributed.tensor.placement_types
+.. py:module:: torch.distributed.tensor.experimental
+.. py:module:: torch.distributed.tensor.experimental.attention
+.. py:module:: torch.distributed.tensor.experimental.func_map
+.. py:module:: torch.distributed.tensor.experimental.register_sharding
+.. py:module:: torch.distributed.tensor.experimental.tp_transform
+.. py:module:: torch.distributed.tensor.debug
+.. py:module:: torch.distributed.tensor.debug.comm_mode
+.. py:module:: torch.distributed.tensor.debug.visualize_sharding
@@ -74,12 +74,13 @@ Features described in this documentation are classified by release status:
    torch.backends <backends>
    torch.export <export>
    torch.distributed <distributed>
+   torch.distributed.tensor <distributed.tensor>
    torch.distributed.algorithms.join <distributed.algorithms.join>
    torch.distributed.elastic <distributed.elastic>
    torch.distributed.fsdp <fsdp>
+   torch.distributed.tensor.parallel <distributed.tensor.parallel>
    torch.distributed.optim <distributed.optim>
    torch.distributed.pipelining <distributed.pipelining>
-   torch.distributed.tensor.parallel <distributed.tensor.parallel>
    torch.distributed.checkpoint <distributed.checkpoint>
    torch.distributions <distributions>
    torch.compiler <torch.compiler>
@@ -33,7 +33,8 @@
     "torch.nn.quantizable": "torch.ao.nn.quantizable",
     "torch.nn.quantizable.modules": "torch.ao.nn.quantizable.modules",
     "torch.nn.quantizable.modules.activation": "torch.ao.nn.quantizable.modules.activation",
-    "torch.nn.quantizable.modules.rnn": "torch.ao.nn.quantizable.modules.rnn"
+    "torch.nn.quantizable.modules.rnn": "torch.ao.nn.quantizable.modules.rnn",
+    "torch.distributed.tensor.device_mesh": "torch.distributed.device_mesh"
   },
   "torch.backends": [
     "contextmanager"
@@ -231,6 +232,9 @@
     "urlunparse"
   ],
   "torch.distributed.rpc": [],
+  "torch.distributed.tensor": [
+    "DeviceMesh"
+  ],
   "torch.fft": [
     "Tensor",
     "fft",
@@ -3,9 +3,9 @@
 import torch
 from torch.distributed._tensor import DeviceMesh
-from torch.distributed._tensor._op_schema import OpSchema
-from torch.distributed._tensor.ops._common_rules import einop_rule, pointwise_rule
 from torch.distributed._tensor.placement_types import DTensorSpec, TensorMeta
+from torch.distributed.tensor._op_schema import OpSchema
+from torch.distributed.tensor._ops._common_rules import einop_rule, pointwise_rule
 from torch.testing._internal.common_utils import run_tests
 from torch.testing._internal.distributed._tensor.common_dtensor import (
     DTensorTestBase,
@@ -909,7 +909,7 @@ class TestDTensorPlacementTypes(DTensorTestBase):
 ]
 assert_array_equal(expected_pad_sizes, pad_sizes)

-from torch.distributed._tensor._collective_utils import unpad_tensor
+from torch.distributed.tensor._collective_utils import unpad_tensor

 unpadded_list = [
     unpad_tensor(tensor, shard_placement.dim, pad_sizes[i])
@@ -167,7 +167,7 @@ class TestEmbeddingOp(DTensorTestBase):
 self._run_embedding_op_test(mesh, 0, [6, 7, 6], 13, 22)
 self._run_embedding_op_test(mesh, 0, [34], 15, 14, padding_idx=10)

-from torch.distributed._tensor.ops._embedding_ops import _MaskPartial
+from torch.distributed.tensor._ops._embedding_ops import _MaskPartial

 # test collectives
 embedding_mod = torch.nn.Embedding(10, 20, device=self.device_type)
@@ -191,7 +191,7 @@ class TestEmbeddingOp(DTensorTestBase):
 inp = torch.randint(0, 10, (4, 4), device=self.device_type)
 replicated_inp = DTensor.from_local(inp, mesh, [Replicate()], run_check=False)

-from torch.distributed._tensor.ops._embedding_ops import _MaskPartial
+from torch.distributed.tensor._ops._embedding_ops import _MaskPartial

 # case 1: two embeddings with the same shape, thus sharing the underlying _MaskPartial
 # and MaskBuffer, because of cache hit from sharding propagation
@@ -7,8 +7,8 @@ import itertools
 import torch
 from torch.distributed._tensor import DeviceMesh, distribute_module, distribute_tensor
 from torch.distributed._tensor.debug import CommDebugMode
-from torch.distributed._tensor.ops.utils import is_tensor_partial, normalize_dim
 from torch.distributed._tensor.placement_types import Replicate, Shard
+from torch.distributed.tensor._ops.utils import is_tensor_partial, normalize_dim
 from torch.testing._internal.common_utils import run_tests
 from torch.testing._internal.distributed._tensor.common_dtensor import (
     DTensorTestBase,
@@ -4,12 +4,6 @@ from itertools import chain

 import torch
 from torch.distributed._tensor import DeviceMesh, DTensor
-from torch.distributed._tensor._collective_utils import redistribute_cost
-from torch.distributed._tensor._op_schema import OpSchema, OpStrategy, PlacementStrategy
-from torch.distributed._tensor.ops._einsum_strategy import (
-    EinsumDims,
-    gen_einsum_strategies,
-)
 from torch.distributed._tensor.placement_types import (
     DTensorSpec,
     Partial,
@@ -17,6 +11,12 @@ from torch.distributed._tensor.placement_types import (
     Shard,
     TensorMeta,
 )
+from torch.distributed.tensor._collective_utils import redistribute_cost
+from torch.distributed.tensor._op_schema import OpSchema, OpStrategy, PlacementStrategy
+from torch.distributed.tensor._ops._einsum_strategy import (
+    EinsumDims,
+    gen_einsum_strategies,
+)
 from torch.testing._internal.common_utils import run_tests, TestCase
 from torch.testing._internal.distributed._tensor.common_dtensor import DTensorOpTestBase
@@ -169,7 +169,7 @@ class TestCostModel(DTensorOpTestBase):

 def test_redistribute_cost_latency(self):
     # test cost model on addmm op
-    from torch.distributed._tensor.ops._matrix_ops import addmm_strategy
+    from torch.distributed.tensor._ops._matrix_ops import addmm_strategy

     mesh = self.build_device_mesh()
     shard0_placement = (Shard(0),)
@@ -246,7 +246,7 @@ class TestCostModel(DTensorOpTestBase):
 self.assertTrue(allreduce_cost > reduce_scatter_cost)

 def test_mm_strategies(self):
-    from torch.distributed._tensor.ops._matrix_ops import mm_strategy
+    from torch.distributed.tensor._ops._matrix_ops import mm_strategy

     mesh = self.build_device_mesh()
     lhs_tensor = torch.randn(6, 8)
@@ -292,7 +292,7 @@ class TestCostModel(DTensorOpTestBase):
 self.assertFalse(output_sharding.needs_redistribute)

 def test_bmm_strategies(self):
-    from torch.distributed._tensor.ops._matrix_ops import bmm_strategy
+    from torch.distributed.tensor._ops._matrix_ops import bmm_strategy

     mesh = self.build_device_mesh()
     lhs_tensor = torch.randn(8, 6, 8)
@@ -5,10 +5,10 @@ import itertools

 import torch
 from torch.distributed._tensor import DeviceMesh, distribute_tensor, DTensor
-from torch.distributed._tensor._collective_utils import shard_dim_alltoall
 from torch.distributed._tensor.debug import CommDebugMode
 from torch.distributed._tensor.placement_types import Partial, Replicate, Shard
 from torch.distributed.device_mesh import init_device_mesh
+from torch.distributed.tensor._collective_utils import shard_dim_alltoall
 from torch.testing._internal.common_utils import run_tests
 from torch.testing._internal.distributed._tensor.common_dtensor import (
     DTensorTestBase,
@@ -207,7 +207,7 @@ class RedistributeTest(DTensorTestBase):
 with self.assertRaisesRegex(RuntimeError, "Can not redistribute to Partial"):
     partial_tensor = replica_tensor.redistribute(device_mesh, [partial_spec])

-from torch.distributed._tensor._redistribute import Redistribute
+from torch.distributed.tensor._redistribute import Redistribute

 comm_mode = CommDebugMode()

@@ -445,7 +445,7 @@ class DistTensorOpsTest(DTensorTestBase):
 # case 2 input sharding: input sharded, index replicated, output mask partial
 # only works when index has size 1 on the gather dimension and
 # input is sharded on the gather dimension
-from torch.distributed._tensor.ops._embedding_ops import _MaskPartial
+from torch.distributed.tensor._ops._embedding_ops import _MaskPartial

 gather_dim = 1
 global_input = torch.randn(12, 8, 16)
@@ -9,7 +9,8 @@ import torch.distributed as dist
 from torch import rand, randn, Tensor
 from torch.distributed._tensor import DeviceMesh, distribute_tensor, Replicate, Shard
 from torch.distributed._tensor.debug import CommDebugMode
-from torch.distributed._tensor.ops._view_ops import (
+from torch.distributed._tensor.placement_types import Placement
+from torch.distributed.tensor._ops._view_ops import (
     Broadcast,
     dim_maps,
     Flatten,
@@ -19,7 +20,6 @@ from torch.distributed._tensor.ops._view_ops import (
     Split,
     view_groups,
 )
-from torch.distributed._tensor.placement_types import Placement
 from torch.testing._internal.common_utils import run_tests
 from torch.testing._internal.distributed._tensor.common_dtensor import (
     DTensorTestBase,
@@ -5,11 +5,6 @@ import os
 import torch
 import torch.distributed._functional_collectives as funcol
 from torch.distributed._tensor import DTensor
-from torch.distributed._tensor._collective_utils import (
-    mesh_broadcast,
-    mesh_scatter,
-    unpad_tensor,
-)
 from torch.distributed._tensor.placement_types import _Partial, Shard
 from torch.distributed.device_mesh import _mesh_resources, DeviceMesh, init_device_mesh
 from torch.distributed.distributed_c10d import (
@@ -22,6 +17,11 @@ from torch.distributed.distributed_c10d import (
     is_nccl_available,
     ProcessGroup,
 )
+from torch.distributed.tensor._collective_utils import (
+    mesh_broadcast,
+    mesh_scatter,
+    unpad_tensor,
+)
 from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
 from torch.testing._internal.common_utils import run_tests
 from torch.testing._internal.distributed._tensor.common_dtensor import (
@@ -1432,12 +1432,12 @@ class GuardBuilder(GuardBuilderBase):
     }
 )
 if torch.distributed.is_available():
-    from torch.distributed._tensor.placement_types import (
+    from torch.distributed.device_mesh import DeviceMesh
+    from torch.distributed.tensor.placement_types import (
         Partial,
         Replicate,
         Shard,
     )
-    from torch.distributed.device_mesh import DeviceMesh

     ok_types = ok_types + (
         Shard,
@@ -142,7 +142,7 @@ manual_torch_name_rule_map = {
     "torch.distributed.is_initialized": TorchInGraphFunctionVariable,
     "torch.distributed.get_rank": TorchInGraphFunctionVariable,
     "torch.distributed.get_world_size": TorchInGraphFunctionVariable,
-    "torch.distributed._tensor.api.DTensor#from_local": TorchInGraphFunctionVariable,
+    "torch.distributed.tensor.api.DTensor#from_local": TorchInGraphFunctionVariable,
     "torch.distributed.distributed_c10d._get_group_size_by_name": TorchInGraphFunctionVariable,
     "torch.distributed.distributed_c10d._resolve_group_name_by_ranks_and_tag": TorchInGraphFunctionVariable,
     "torch.distributed.distributed_c10d._get_group_tag": TorchInGraphFunctionVariable,
@@ -3190,8 +3190,8 @@ LEGACY_MOD_INLINELIST = {

 if torch.distributed.is_available():
     LEGACY_MOD_INLINELIST |= {
-        "torch.distributed._tensor.api",
-        "torch.distributed._tensor.device_mesh",
+        "torch.distributed.tensor.api",
+        "torch.distributed.tensor.device_mesh",
         "torch.distributed.device_mesh",
         "torch.distributed.algorithms._checkpoint.checkpoint_wrapper",
         "torch.distributed.tensor.parallel._data_parallel_utils",
@@ -50,7 +50,7 @@ class DistributedVariable(VariableTracker):
 def is_from_local(value):
     if not DistributedVariable.is_available():
         return False
-    from torch.distributed._tensor import DTensor
+    from torch.distributed.tensor import DTensor

     return inspect.isfunction(value) and value is DTensor.from_local

@@ -108,7 +108,7 @@ class PlacementClassVariable(DistributedVariable):
 if not DistributedVariable.is_available():
     return False

-from torch.distributed._tensor.placement_types import Placement
+from torch.distributed.tensor.placement_types import Placement

 return type(value) is type and issubclass(value, Placement)

@@ -143,7 +143,7 @@ class PlacementVariable(DistributedVariable):
 if not DistributedVariable.is_available():
     return False

-from torch.distributed._tensor.placement_types import Placement
+from torch.distributed.tensor.placement_types import Placement

 return isinstance(value, Placement)

@@ -598,7 +598,6 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
 )

 if DistributedVariable.is_available():
-    from torch.distributed._tensor import DTensor
     from torch.distributed.distributed_c10d import (
         _get_group_size_by_name,
         _get_group_tag,
@@ -606,6 +605,7 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
         _resolve_group_name_by_ranks_and_tag,
         get_process_group_ranks,
     )
+    from torch.distributed.tensor import DTensor

     @register(
         _get_group_size_by_name,
@@ -4,8 +4,8 @@ from typing import cast, List, NamedTuple, Optional, Tuple, Union
 import torch
 import torch._dynamo.compiled_autograd as ca
 import torch.distributed as dist
-from torch.distributed._tensor import DTensor
 from torch.distributed.distributed_c10d import ReduceOp
+from torch.distributed.tensor import DTensor

 from ._fsdp_common import (
     _get_dim0_padded_size,
@@ -10,8 +10,8 @@ import torch._dynamo.compiled_autograd as ca
 import torch.distributed as dist
 import torch.nn as nn
 from torch.distributed._composable.contract import _get_registry
-from torch.distributed._tensor import DeviceMesh, DTensor
-from torch.distributed._tensor.placement_types import DTensorSpec
+from torch.distributed.tensor import DeviceMesh, DTensor
+from torch.distributed.tensor.placement_types import DTensorSpec


 @dataclass
@@ -4,8 +4,8 @@ from typing import List, Optional, Set, Tuple, Union
 import torch
 import torch.distributed as dist
 import torch.nn as nn
-from torch.distributed._tensor import DeviceMesh, DTensor, init_device_mesh
 from torch.distributed.device_mesh import _get_device_handle
+from torch.distributed.tensor import DeviceMesh, DTensor, init_device_mesh
 from torch.utils._python_dispatch import is_traceable_wrapper_subclass

 from ._fsdp_common import _is_composable_with_fsdp, FSDPMeshInfo, HSDPMeshInfo
@@ -9,9 +9,9 @@ import torch._dynamo.compiled_autograd as ca
 import torch.nn as nn
 from torch._prims_common import make_contiguous_strides_for
 from torch.distributed._functional_collectives import AsyncCollectiveTensor
-from torch.distributed._tensor import DTensor, Replicate, Shard
-from torch.distributed._tensor.device_mesh import _mesh_resources
-from torch.distributed._tensor.placement_types import (
+from torch.distributed.tensor import DTensor, Replicate, Shard
+from torch.distributed.tensor.device_mesh import _mesh_resources
+from torch.distributed.tensor.placement_types import (
     _StridedShard,
     DTensorSpec,
     Placement,
@@ -6,7 +6,7 @@ from typing import Any, cast, Iterable, List, NoReturn, Optional, Union
 import torch
 import torch.nn as nn
 from torch.distributed._composable import contract
-from torch.distributed._tensor import DeviceMesh
+from torch.distributed.tensor import DeviceMesh
 from torch.distributed.utils import _get_root_modules

 from ._fsdp_api import MixedPrecisionPolicy, OffloadPolicy
@@ -97,7 +97,7 @@ RANK_TYPES = Union[
     List[List[int]],
     dist.ProcessGroup,
     DeviceMesh,
-    Tuple["dist._tensor.DeviceMesh", int],
+    Tuple["dist.tensor.DeviceMesh", int],
     str,
 ]

@@ -27,7 +27,7 @@ from torch.distributed._functional_collectives import AsyncCollectiveTensor
 if dist.is_available() or TYPE_CHECKING:
     from torch.distributed import distributed_c10d
     from torch.distributed._shard.sharded_tensor import ShardedTensor
-    from torch.distributed._tensor import distribute_tensor, DTensor, Replicate
+    from torch.distributed.tensor import distribute_tensor, DTensor, Replicate


 def _identity_func(
@@ -1,58 +1,43 @@
-# mypy: allow-untyped-defs
-# Copyright (c) Meta Platforms, Inc. and affiliates
-
-import torch
-import torch.distributed._tensor.ops as _ops  # force import all built-in dtensor ops
-from torch.distributed._tensor.api import (
-    distribute_module,
-    distribute_tensor,
-    DTensor,
-    empty,
-    full,
-    ones,
-    rand,
-    randn,
-    zeros,
-)
-from torch.distributed._tensor.placement_types import (
-    Partial,
-    Placement,
-    Replicate,
-    Shard,
-)
-from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
-from torch.optim.optimizer import (
-    _foreach_supported_types as _optim_foreach_supported_types,
-)
-from torch.utils._foreach_utils import (
-    _foreach_supported_types as _util_foreach_supported_types,
-)
-
-
-# All public APIs from dtensor package
-__all__ = [
-    "DTensor",
-    "DeviceMesh",
-    "distribute_tensor",
-    "distribute_module",
-    "init_device_mesh,",
-    "Shard",
-    "Replicate",
-    "Partial",
-    "Placement",
-    "ones",
-    "empty",
-    "full",
-    "rand",
-    "randn",
-    "zeros",
-]
-
-
-# Append DTensor to the list of supported types for foreach implementation for optimizer
-# and clip_grad_norm_ so that we will try to use foreach over the for-loop implementation on CUDA.
-if DTensor not in _optim_foreach_supported_types:
-    _optim_foreach_supported_types.append(DTensor)
-
-if DTensor not in _util_foreach_supported_types:
-    _util_foreach_supported_types.append(DTensor)
+"""
+NOTICE: DTensor has moved to torch.distributed.tensor
+
+This file is a shim to redirect to the new location, and
+we keep the old import path that starts with `_tensor` for
+backward compatibility.
+"""
+import importlib
+import sys
+
+import torch.distributed.tensor
+
+
+def _populate():  # type: ignore[no-untyped-def]
+    for name in (
+        # TODO: _utils here mainly for checkpoint imports BC, remove it
+        "_utils",
+        "api",
+        "debug",
+        "device_mesh",
+        "experimental",
+        "placement_types",
+        "random",
+    ):
+        try:
+            globals()[name] = sys.modules[
+                f"torch.distributed._tensor.{name}"
+            ] = importlib.import_module(f"torch.distributed.tensor.{name}")
+        except ImportError as e:
+            import traceback
+
+            traceback.print_exc()
+            raise ImportError(
+                f"Failed to import torch.distributed.tensor.{name} due to {e}"
+            ) from e
+
+    for name, val in torch.distributed.tensor.__dict__.items():
+        # Skip private names and tensor parallel package
+        if not name.startswith("_") and name != "parallel":
+            globals()[name] = val
+
+
+_populate()
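An editorial sketch (not part of the diff) of what the shim above provides: old submodule paths are aliased in ``sys.modules`` to the new package, so legacy imports keep resolving to the same module objects. Assumes a build containing this change.

```python
# Minimal sketch: the legacy submodule name is an alias of the new module object.
import sys
import torch.distributed._tensor  # triggers _populate() in the shim

from torch.distributed._tensor.placement_types import Shard as LegacyShard
from torch.distributed.tensor.placement_types import Shard

assert LegacyShard is Shard
assert (
    sys.modules["torch.distributed._tensor.placement_types"]
    is sys.modules["torch.distributed.tensor.placement_types"]
)
```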
@@ -14,8 +14,8 @@ from typing import (

 import torch
 from torch.distributed._shard.sharded_tensor.api import ShardedTensor
-from torch.distributed._tensor import DTensor
 from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE
+from torch.distributed.tensor import DTensor


 PATH_ITEM = Union[str, int]
@@ -11,7 +11,6 @@ from typing import Any, cast, Dict, List, Optional, Tuple, Union

 import torch
 from torch.distributed._shard._utils import narrow_tensor_by_index
-from torch.distributed._tensor import DTensor
 from torch.distributed.checkpoint._dedup_save_plans import dedup_save_plans
 from torch.distributed.checkpoint._nested_dict import (
     FLATTEN_MAPPING,
@@ -45,6 +44,7 @@ from torch.distributed.checkpoint.planner_helpers import (
     _init_state_dict,
 )
 from torch.distributed.checkpoint.utils import find_state_dict_object
+from torch.distributed.tensor import DTensor


 logger: logging.Logger = logging.getLogger(__name__)
@@ -11,12 +11,12 @@ import torch.distributed.checkpoint as dcp
 import torch.multiprocessing as mp
 import torch.nn as nn
 import torch.nn.functional as F
-from torch.distributed._tensor.device_mesh import init_device_mesh
 from torch.distributed.checkpoint.state_dict import (
     _patch_model_state_dict,
     _patch_optimizer_state_dict,
 )
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+from torch.distributed.tensor.device_mesh import init_device_mesh


 DEVICE = "cuda"
@@ -12,11 +12,11 @@ import torch.distributed as dist
 import torch.distributed.checkpoint as dcp
 import torch.multiprocessing as mp
 import torch.nn as nn
-from torch.distributed._tensor.device_mesh import init_device_mesh
 from torch.distributed.checkpoint.state_dict import (
     _patch_model_state_dict,
     _patch_optimizer_state_dict,
 )
+from torch.distributed.device_mesh import init_device_mesh
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

@@ -12,7 +12,6 @@ from torch.distributed._shard.sharded_tensor.metadata import (
 )
 from torch.distributed._shard.sharded_tensor.shard import Shard
 from torch.distributed._shard.sharding_spec.chunk_sharding_spec import ChunkShardingSpec
-from torch.distributed._tensor import DTensor
 from torch.distributed.checkpoint._nested_dict import unflatten_state_dict
 from torch.distributed.checkpoint.default_planner import DefaultLoadPlanner
 from torch.distributed.checkpoint.metadata import (
@@ -39,6 +38,7 @@ from torch.distributed.checkpoint.utils import (
 from torch.distributed.distributed_c10d import _get_default_group
 from torch.distributed.fsdp._shard_utils import _create_chunk_sharded_tensor
 from torch.distributed.remote_device import _remote_device
+from torch.distributed.tensor import DTensor


 STATE_DICT_2D_LAYOUT = Dict[str, Tuple[Optional[Sequence[int]], Sequence[int]]]
@@ -7,8 +7,8 @@ import torch.distributed as dist
 from torch._utils import _get_device_module
 from torch.distributed._shard.metadata import ShardMetadata
 from torch.distributed._shard.sharded_tensor import ShardedTensor
-from torch.distributed._tensor import DTensor
-from torch.distributed._tensor._utils import compute_local_shape_and_global_offset
+from torch.distributed.tensor import DTensor
+from torch.distributed.tensor._utils import compute_local_shape_and_global_offset

 from .metadata import (
     BytesStorageMetadata,
@@ -32,7 +32,6 @@ from torch.distributed._state_dict_utils import (
     _offload_state_dict_to_cpu,
     _unflatten_state_dict,
 )
-from torch.distributed._tensor import DTensor
 from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
     _CHECKPOINT_PREFIX,
 )
@@ -50,6 +49,7 @@ from torch.distributed.fsdp._common_utils import (
     _get_module_fsdp_state_if_fully_sharded_module,
     FSDP_WRAPPED_MODULE,
 )
+from torch.distributed.tensor import DTensor
 from torch.nn.modules.module import _IncompatibleKeys
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils._pytree import tree_map_only
@@ -1887,7 +1887,7 @@ class FlatParamHandle:
 flat_param = self.flat_param
 self._check_unsharded(flat_param)
 views = self._get_unflat_views()
-from torch.distributed._tensor import DTensor
+from torch.distributed.tensor import DTensor

 for i, (view, (param_name, module, _)) in enumerate(
     zip(views, flat_param._param_infos)
@@ -2717,7 +2717,7 @@ def _warn_use_fake_reduce(log: logging.Logger, warning: str):
 def _same_storage(a, b):
     # Params are DTensors in backward
     # with SHARD_GRAD_OP + TP
-    from torch.distributed._tensor import DTensor
+    from torch.distributed.tensor import DTensor

     if isinstance(a, DTensor):
         a = a._local_tensor
@@ -5,12 +5,12 @@ import torch
 import torch.distributed as dist
 from torch.distributed._shard.sharded_tensor.api import ShardedTensor
 from torch.distributed._shard.sharded_tensor.shard import Shard
-from torch.distributed._tensor import DeviceMesh, DTensor
 from torch.distributed.fsdp._shard_utils import (
     _all_gather_dtensor,
     _create_chunk_dtensor,
     _create_chunk_sharded_tensor,
 )
+from torch.distributed.tensor import DeviceMesh, DTensor


 class FSDPExtensions(ABC):
@@ -27,7 +27,6 @@ import torch.distributed as dist
 import torch.distributed.fsdp._traversal_utils as traversal_utils
 import torch.nn as nn
 from torch.distributed._state_dict_utils import _gather_state_dict
-from torch.distributed._tensor import DTensor, Replicate
 from torch.distributed.distributed_c10d import _get_pg_default_device
 from torch.distributed.fsdp._common_utils import (
     _apply_to_modules,
@@ -53,6 +52,7 @@ from torch.distributed.fsdp.api import (
     StateDictSettings,
     StateDictType,
 )
+from torch.distributed.tensor import DTensor, Replicate
 from torch.utils._pytree import tree_map_only

@@ -15,7 +15,7 @@ from torch.distributed._shard.sharded_tensor import (
     TensorProperties,
 )
 from torch.distributed._shard.sharding_spec import ShardMetadata
-from torch.distributed._tensor import DeviceMesh, DTensor, Replicate, Shard as DShard
+from torch.distributed.tensor import DeviceMesh, DTensor, Replicate, Shard as DShard


 def _get_remote_device_str(rank, device_type, num_devices_per_node):
@@ -25,7 +25,6 @@ from torch.distributed._shard.sharded_tensor import (
     Shard,
     ShardedTensor,
 )
-from torch.distributed._tensor import DTensor
 from torch.distributed.device_mesh import _mesh_resources
 from torch.distributed.fsdp._common_utils import (
     _FSDPState,
@@ -49,6 +48,7 @@ from torch.distributed.fsdp.api import (
     ShardingStrategy,
     StateDictType,
 )
+from torch.distributed.tensor import DTensor
 from torch.distributed.utils import _replace_by_prefix

 from ._fsdp_extensions import (
@@ -25,7 +25,6 @@ import torch
 import torch.distributed as dist
 import torch.distributed.fsdp._traversal_utils as traversal_utils
 import torch.nn as nn
-from torch.distributed._tensor import DeviceMesh
 from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
     _CHECKPOINT_WRAPPED_MODULE,
     ActivationWrapper,
@@ -84,6 +83,7 @@ from torch.distributed.fsdp.api import (
     StateDictSettings,
     StateDictType,
 )
+from torch.distributed.tensor import DeviceMesh
 from torch.distributed.utils import _p_assert

 from ._flat_param import FlatParameter, FlatParamHandle
@@ -0,0 +1,57 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+
+import torch
+import torch.distributed.tensor._ops  # force import all built-in dtensor ops
+from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
+from torch.distributed.tensor.api import (
+    distribute_module,
+    distribute_tensor,
+    DTensor,
+    empty,
+    full,
+    ones,
+    rand,
+    randn,
+    zeros,
+)
+from torch.distributed.tensor.placement_types import (
+    Partial,
+    Placement,
+    Replicate,
+    Shard,
+)
+from torch.optim.optimizer import (
+    _foreach_supported_types as _optim_foreach_supported_types,
+)
+from torch.utils._foreach_utils import (
+    _foreach_supported_types as _util_foreach_supported_types,
+)
+
+
+# All public APIs from dtensor package
+__all__ = [
+    "DTensor",
+    "DeviceMesh",
+    "distribute_tensor",
+    "distribute_module",
+    "init_device_mesh",
+    "Shard",
+    "Replicate",
+    "Partial",
+    "Placement",
+    "ones",
+    "empty",
+    "full",
+    "rand",
+    "randn",
+    "zeros",
+]
+
+
+# Append DTensor to the list of supported types for foreach implementation for optimizer
+# and clip_grad_norm_ so that we will try to use foreach over the for-loop implementation on CUDA.
+if DTensor not in _optim_foreach_supported_types:
+    _optim_foreach_supported_types.append(DTensor)
+
+if DTensor not in _util_foreach_supported_types:
+    _util_foreach_supported_types.append(DTensor)
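As an editorial aside (not part of the diff), a minimal sketch of what the foreach registration above implies: once the package is imported, ``DTensor`` appears in the foreach-supported type lists consulted by the optimizers and ``clip_grad_norm_``.

```python
# Minimal sketch: importing the package registers DTensor for the foreach paths.
from torch.distributed.tensor import DTensor
from torch.optim.optimizer import _foreach_supported_types as optim_types
from torch.utils._foreach_utils import _foreach_supported_types as util_types

assert DTensor in optim_types
assert DTensor in util_types
```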
@@ -7,7 +7,7 @@ from typing import List, Optional

 import torch
 import torch.distributed._functional_collectives as funcol
-import torch.distributed._tensor.placement_types as placement_types
+import torch.distributed.tensor.placement_types as placement_types
 from torch._C._distributed_c10d import _resolve_process_group
 from torch.distributed.device_mesh import _mesh_resources, DeviceMesh
 from torch.distributed.distributed_c10d import (
@@ -8,24 +8,24 @@ from typing import cast, Dict, List, Optional, Sequence, Tuple, TYPE_CHECKING

 import torch
 import torch.distributed as dist
-import torch.distributed._tensor.api as dtensor
-import torch.distributed._tensor.random as random
-from torch.distributed._tensor._op_schema import (
+import torch.distributed.tensor.api as dtensor
+import torch.distributed.tensor.random as random
+from torch.distributed.tensor._op_schema import (
     _is_inplace_op,
     _is_out_variant_op,
     OpInfo,
     OpSchema,
     OutputSpecType,
 )
-from torch.distributed._tensor._redistribute import redistribute_local_tensor
-from torch.distributed._tensor._sharding_prop import ShardingPropagator
-from torch.distributed._tensor._tp_conv import (
+from torch.distributed.tensor._redistribute import redistribute_local_tensor
+from torch.distributed.tensor._sharding_prop import ShardingPropagator
+from torch.distributed.tensor._tp_conv import (
     convolution_backward_handler,
     convolution_handler,
 )
-from torch.distributed._tensor._utils import try_find_mesh_from_args
-from torch.distributed._tensor.placement_types import DTensorSpec, Replicate, TensorMeta
-from torch.distributed._tensor.random import is_rng_supported_mesh
+from torch.distributed.tensor._utils import try_find_mesh_from_args
+from torch.distributed.tensor.placement_types import DTensorSpec, Replicate, TensorMeta
+from torch.distributed.tensor.random import is_rng_supported_mesh


 if TYPE_CHECKING:
@@ -5,8 +5,8 @@ from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

 import torch
 from torch._ops import OpOverload
-from torch.distributed._tensor.placement_types import DTensorSpec, Placement
 from torch.distributed.device_mesh import DeviceMesh
+from torch.distributed.tensor.placement_types import DTensorSpec, Placement


 try:
@@ -2,15 +2,15 @@
 from typing import cast, Dict, List, Optional, Tuple

 import torch
-from torch.distributed._tensor._op_schema import (
+from torch.distributed.tensor._op_schema import (
     _is_inplace_op,
     _is_out_variant_op,
     OpSchema,
     OutputSharding,
 )
-from torch.distributed._tensor._utils import compute_local_shape
-from torch.distributed._tensor.ops.utils import prod
-from torch.distributed._tensor.placement_types import DTensorSpec, TensorMeta
+from torch.distributed.tensor._ops.utils import prod
+from torch.distributed.tensor._utils import compute_local_shape
+from torch.distributed.tensor.placement_types import DTensorSpec, TensorMeta


 def _replace_char_in_str(string: str, new_char: str, idx: int) -> str:
@@ -4,9 +4,9 @@
 from typing import List

 import torch
-from torch.distributed._tensor._op_schema import OpSchema, OutputSharding
-from torch.distributed._tensor.ops.utils import register_prop_rule
-from torch.distributed._tensor.placement_types import DTensorSpec, TensorMeta
+from torch.distributed.tensor._op_schema import OpSchema, OutputSharding
+from torch.distributed.tensor._ops.utils import register_prop_rule
+from torch.distributed.tensor.placement_types import DTensorSpec, TensorMeta


 aten = torch.ops.aten
@@ -2,15 +2,15 @@ import itertools
 from dataclasses import dataclass
 from typing import List, Set, Tuple

-from torch.distributed._tensor._op_schema import OpStrategy, PlacementStrategy
-from torch.distributed._tensor.placement_types import (
+from torch.distributed.device_mesh import DeviceMesh
+from torch.distributed.tensor._op_schema import OpStrategy, PlacementStrategy
+from torch.distributed.tensor.placement_types import (
     DTensorSpec,
     Partial,
     Placement,
     Replicate,
     Shard,
 )
-from torch.distributed.device_mesh import DeviceMesh


 @dataclass
@@ -7,23 +7,23 @@ from typing import cast, Optional

 import torch
 import torch.distributed._functional_collectives as funcol
-from torch.distributed._tensor._op_schema import (
+from torch.distributed.device_mesh import DeviceMesh
+from torch.distributed.tensor._op_schema import (
     OpSchema,
     OpStrategy,
     PlacementList,
     StrategyType,
 )
-from torch.distributed._tensor.ops.utils import (
+from torch.distributed.tensor._ops.utils import (
     expand_to_full_mesh_op_strategy,
     register_op_strategy,
 )
-from torch.distributed._tensor.placement_types import (
+from torch.distributed.tensor.placement_types import (
     Partial,
     Placement,
     Replicate,
     Shard,
 )
-from torch.distributed.device_mesh import DeviceMesh


 aten = torch.ops.aten
@@ -3,15 +3,15 @@
 # implement matrix related ops for distributed tensor

 import torch
-from torch.distributed._tensor._op_schema import (
+from torch.distributed.tensor._op_schema import (
     OpSchema,
     OpStrategy,
     PlacementStrategy,
     StrategyType,
 )
-from torch.distributed._tensor.device_mesh import DeviceMesh
-from torch.distributed._tensor.ops.utils import register_op_strategy
-from torch.distributed._tensor.placement_types import DTensorSpec, Replicate
+from torch.distributed.tensor._ops.utils import register_op_strategy
+from torch.distributed.tensor.device_mesh import DeviceMesh
+from torch.distributed.tensor.placement_types import DTensorSpec, Replicate


 aten = torch.ops.aten
@@ -7,7 +7,8 @@ from enum import Enum
 from typing import cast, List, Optional, Sequence, Tuple, Union

 import torch
-from torch.distributed._tensor._op_schema import (
+from torch.distributed.device_mesh import DeviceMesh
+from torch.distributed.tensor._op_schema import (
     OpSchema,
     OpStrategy,
     PlacementList,
@@ -15,8 +16,7 @@ from torch.distributed._tensor._op_schema import (
     RuntimeSchemaInfo,
     TupleStrategy,
 )
-from torch.distributed._tensor._utils import normalize_to_torch_size
-from torch.distributed._tensor.ops.utils import (
+from torch.distributed.tensor._ops.utils import (
     as_list,
     expand_to_full_mesh_op_strategy,
     generate_redistribute_costs,
@@ -25,14 +25,14 @@ from torch.distributed._tensor.ops.utils import (
     normalize_dims,
     register_op_strategy,
 )
-from torch.distributed._tensor.placement_types import (
+from torch.distributed.tensor._utils import normalize_to_torch_size
+from torch.distributed.tensor.placement_types import (
     DTensorSpec,
     Partial,
     Placement,
     Replicate,
     Shard,
 )
-from torch.distributed.device_mesh import DeviceMesh


 aten = torch.ops.aten
@@ -5,14 +5,15 @@
 from typing import List

 import torch
-from torch.distributed._tensor._op_schema import (
+from torch.distributed.device_mesh import DeviceMesh
+from torch.distributed.tensor._op_schema import (
     OpSchema,
     OpStrategy,
     PlacementList,
     PlacementStrategy,
 )
-from torch.distributed._tensor.ops._einsum_strategy import gen_einsum_strategies
-from torch.distributed._tensor.ops.utils import (
+from torch.distributed.tensor._ops._einsum_strategy import gen_einsum_strategies
+from torch.distributed.tensor._ops.utils import (
     expand_to_full_mesh_op_strategy,
     generate_redistribute_costs,
     infer_broadcast_dims_map,
@@ -20,13 +21,12 @@ from torch.distributed._tensor.ops.utils import (
     map_placements_after_broadcast,
     register_op_strategy,
 )
-from torch.distributed._tensor.placement_types import (
+from torch.distributed.tensor.placement_types import (
     DTensorSpec,
     Placement,
     Replicate,
     Shard,
 )
-from torch.distributed.device_mesh import DeviceMesh


 aten = torch.ops.aten
@@ -2,7 +2,8 @@
 from typing import List, Sequence, Tuple

 import torch
-from torch.distributed._tensor._op_schema import (
+from torch.distributed.device_mesh import DeviceMesh
+from torch.distributed.tensor._op_schema import (
     _is_inplace_op,
     _is_out_variant_op,
     OpSchema,
@@ -12,21 +13,20 @@ from torch.distributed._tensor._op_schema import (
     StrategyType,
     TupleStrategy,
 )
-from torch.distributed._tensor.ops.utils import (
+from torch.distributed.tensor._ops.utils import (
     generate_redistribute_costs,
     infer_broadcast_dims_map,
     map_placements_after_broadcast,
     normalize_dim,
     register_op_strategy,
 )
-from torch.distributed._tensor.placement_types import (
+from torch.distributed.tensor.placement_types import (
     DTensorSpec,
     Partial,
     Placement,
     Replicate,
     Shard,
 )
-from torch.distributed.device_mesh import DeviceMesh


 aten = torch.ops.aten
@@ -1,14 +1,14 @@
 # mypy: allow-untyped-decorators
 # Copyright (c) Meta Platforms, Inc. and affiliates
 import torch
-from torch.distributed._tensor._op_schema import (
+from torch.distributed.device_mesh import DeviceMesh
+from torch.distributed.tensor._op_schema import (
     OpSchema,
     OpStrategy,
     PlacementStrategy,
     StrategyType,
 )
-from torch.distributed._tensor.ops.utils import is_tensor_partial, register_op_strategy
-from torch.distributed.device_mesh import DeviceMesh
+from torch.distributed.tensor._ops.utils import is_tensor_partial, register_op_strategy


 aten = torch.ops.aten
@@ -4,7 +4,8 @@
 from typing import cast, List, Optional, Sequence, Tuple

 import torch
-from torch.distributed._tensor._op_schema import (
+from torch.distributed.device_mesh import DeviceMesh
+from torch.distributed.tensor._op_schema import (
     _is_inplace_op,
     OpSchema,
     OpStrategy,
@@ -15,9 +16,9 @@ from torch.distributed._tensor._op_schema import (
     StrategyType,
     TupleStrategy,
 )
-from torch.distributed._tensor.ops._common_rules import pointwise_rule
-from torch.distributed._tensor.ops._embedding_ops import _MaskPartial
-from torch.distributed._tensor.ops.utils import (
+from torch.distributed.tensor._ops._common_rules import pointwise_rule
+from torch.distributed.tensor._ops._embedding_ops import _MaskPartial
+from torch.distributed.tensor._ops.utils import (
     expand_to_full_mesh_op_strategy,
     is_tensor_dim_sharded,
     is_tensor_evenly_shardable,
@@ -26,14 +27,13 @@ from torch.distributed._tensor.ops.utils import (
     register_op_strategy,
     register_prop_rule,
 )
-from torch.distributed._tensor.placement_types import (
+from torch.distributed.tensor.placement_types import (
     DTensorSpec,
     Partial,
     Placement,
     Replicate,
     Shard,
 )
-from torch.distributed.device_mesh import DeviceMesh


 aten = torch.ops.aten
@@ -17,23 +17,27 @@ from typing import (

 import torch
 from torch import Tensor
-from torch.distributed._tensor._op_schema import (
+from torch.distributed.device_mesh import DeviceMesh
+from torch.distributed.tensor._op_schema import (
     OpSchema,
     OpStrategy,
     PlacementStrategy,
     RuntimeSchemaInfo,
     StrategyType,
 )
-from torch.distributed._tensor.api import Shard
-from torch.distributed._tensor.ops.utils import (
+from torch.distributed.tensor._ops.utils import (
     generate_redistribute_costs,
     normalize_dim,
     normalize_dims,
     prod,
     register_op_strategy,
 )
-from torch.distributed._tensor.placement_types import DTensorSpec, Placement, Replicate
-from torch.distributed.device_mesh import DeviceMesh
+from torch.distributed.tensor.placement_types import (
+    DTensorSpec,
+    Placement,
+    Replicate,
+    Shard,
+)


 aten = torch.ops.aten
@@ -6,17 +6,17 @@ import operator
 from typing import cast, Iterable, List, Optional, Sequence, Tuple, Union

 import torch
-from torch.distributed._tensor._collective_utils import redistribute_cost
-from torch.distributed._tensor._op_schema import (
+from torch.distributed.tensor._collective_utils import redistribute_cost
+from torch.distributed.tensor._op_schema import (
     OpSchema,
     OpStrategy,
     PlacementList,
     PlacementStrategy,
     RuntimeSchemaInfo,
 )
-from torch.distributed._tensor.api import DTensor
-from torch.distributed._tensor.device_mesh import DeviceMesh
-from torch.distributed._tensor.placement_types import (
+from torch.distributed.tensor.api import DTensor
+from torch.distributed.tensor.device_mesh import DeviceMesh
+from torch.distributed.tensor.placement_types import (
     DTensorSpec,
     Partial,
     Placement,
@@ -6,9 +6,9 @@ from typing import cast, List, NamedTuple, Tuple

 import torch
 import torch.distributed._functional_collectives as funcol
-import torch.distributed._tensor.api as dtensor
-from torch.distributed._tensor.device_mesh import DeviceMesh
-from torch.distributed._tensor.placement_types import (
+import torch.distributed.tensor.api as dtensor
+from torch.distributed.tensor.device_mesh import DeviceMesh
+from torch.distributed.tensor.placement_types import (
     DTensorSpec,
     Partial,
     Placement,
@@ -6,7 +6,8 @@ from typing import Callable, cast, Dict, List, Optional, Sequence, Tuple, Union
 import torch
 from torch._ops import OpOverload
 from torch._subclasses import FakeTensorMode
-from torch.distributed._tensor._op_schema import (
+from torch.distributed.device_mesh import DeviceMesh
+from torch.distributed.tensor._op_schema import (
     OpInfo,
     OpSchema,
     OpStrategy,
@@ -17,13 +18,12 @@ from torch.distributed._tensor._op_schema import (
     StrategyType,
     TupleStrategy,
 )
-from torch.distributed._tensor._utils import (
+from torch.distributed.tensor._utils import (
     compute_local_shape,
     compute_local_stride,
     try_find_mesh_from_args,
 )
-from torch.distributed._tensor.placement_types import DTensorSpec, TensorMeta
-from torch.distributed.device_mesh import DeviceMesh
+from torch.distributed.tensor.placement_types import DTensorSpec, TensorMeta


 aten = torch.ops.aten
@@ -5,7 +5,7 @@ from typing import cast, Dict, List, Tuple

 import torch
 import torch.distributed as dist
-import torch.distributed._tensor.api as dtensor
+import torch.distributed.tensor.api as dtensor


 aten = torch.ops.aten
@@ -1,9 +1,10 @@
 from typing import cast, List, Sequence, Tuple

 import torch
-import torch.distributed._tensor.api as dtensor
+import torch.distributed.tensor.api as dtensor
 from torch._prims_common import ShapeType
-from torch.distributed._tensor.placement_types import (
+from torch.distributed.device_mesh import DeviceMesh
+from torch.distributed.tensor.placement_types import (
     _StridedShard,
     DTensorSpec,
     Partial,
@@ -11,7 +12,6 @@ from torch.distributed._tensor.placement_types import (
     Replicate,
     Shard,
 )
-from torch.distributed.device_mesh import DeviceMesh


 # TODO: audit existing code base to see if we can safely remove this API.
@@ -6,23 +6,21 @@ import warnings
 from typing import Any, Callable, cast, Optional, Sequence, Tuple

 import torch
-import torch.distributed._tensor._dispatch as op_dispatch
-import torch.distributed._tensor.random as random
+import torch.distributed.tensor._dispatch as op_dispatch
+import torch.distributed.tensor.random as random
 import torch.nn as nn
-from torch.distributed._tensor._collective_utils import (
-    check_tensor_meta,
-    mesh_broadcast,
-)
-from torch.distributed._tensor._redistribute import (
+from torch.distributed.device_mesh import _mesh_resources, DeviceMesh
+from torch.distributed.tensor._collective_utils import check_tensor_meta, mesh_broadcast
+from torch.distributed.tensor._redistribute import (
     Redistribute,
     redistribute_local_tensor,
 )
-from torch.distributed._tensor._utils import (
+from torch.distributed.tensor._utils import (
     compute_global_tensor_info,
     compute_local_shape,
     normalize_to_torch_size,
 )
-from torch.distributed._tensor.placement_types import (
+from torch.distributed.tensor.placement_types import (
     DTensorSpec,
     Partial,
     Placement,
@@ -30,11 +28,7 @@ from torch.distributed._tensor.placement_types import (
     Shard,
     TensorMeta,
 )
-from torch.distributed._tensor.random import (
-    is_rng_supported_mesh,
-    OffsetBasedRNGTracker,
-)
-from torch.distributed.device_mesh import _mesh_resources, DeviceMesh
+from torch.distributed.tensor.random import is_rng_supported_mesh, OffsetBasedRNGTracker


 __all__ = [
@@ -254,7 +248,8 @@ class DTensor(torch.Tensor):
         """
         Construct a DTensor from a local tensor, device mesh, and placement and
         other tensor properties (i.e. shape, requires_grad, strides, etc).
-        Note: This is not a public API and it's only supposed to be used by the
+
+        .. note:: This is not a public API and it's only supposed to be used by the
             operator implementations and internals. If you want to construct a
             DTensor from a local tensor, consider using ``DTensor.from_local``, if
             you want to construct a DTensor from a "global" tensor (where you
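The docstring above steers users toward DTensor.from_local when each rank already holds its shard, and toward distribute_tensor when starting from a full ("global") tensor. A hedged sketch of both entry points (the mesh and shapes are placeholders):

    import torch
    from torch.distributed.device_mesh import init_device_mesh
    from torch.distributed.tensor import DTensor, Shard, distribute_tensor

    mesh = init_device_mesh("cuda", (2,))

    # Wrap a tensor that is already the per-rank local shard.
    local_shard = torch.randn(4, 8)
    dt_a = DTensor.from_local(local_shard, mesh, [Shard(0)])

    # Start from the full tensor and let DTensor shard/scatter it.
    full = torch.randn(8, 8)
    dt_b = distribute_tensor(full, mesh, [Shard(0)])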
@@ -1,6 +1,6 @@
 # mypy: allow-untyped-defs
-from torch.distributed._tensor.debug.comm_mode import CommDebugMode
-from torch.distributed._tensor.debug.visualize_sharding import visualize_sharding
+from torch.distributed.tensor.debug.comm_mode import CommDebugMode
+from torch.distributed.tensor.debug.visualize_sharding import visualize_sharding


 __all__ = ["CommDebugMode", "visualize_sharding"]
@@ -12,7 +12,7 @@ def _get_sharding_prop_cache_info():
     This would return a named tuple showing hits, misses, maxsize and cursize of the sharding
     propagator cache.
     """
-    from torch.distributed._tensor.api import DTensor
+    from torch.distributed.tensor.api import DTensor

     return (
         DTensor._op_dispatcher.sharding_propagator.propagate_op_sharding.cache_info()  # type:ignore[attr-defined]
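The cached propagate_op_sharding call above exposes the standard functools cache statistics the docstring describes, so the named tuple can be read directly. A rough sketch, with the caveat that this is an internal attribute chain and may change:

    # Internal API: inspect sharding-propagation cache statistics.
    from torch.distributed.tensor import DTensor

    info = DTensor._op_dispatcher.sharding_propagator.propagate_op_sharding.cache_info()
    print(info.hits, info.misses, info.maxsize, info.currsize)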
@@ -8,7 +8,7 @@ import torch.nn as nn
 from functorch.compile import make_boxed_func
 from torch._functorch.compilers import aot_module
 from torch._inductor.decomposition import select_decomp_table
-from torch.distributed._tensor import DTensor
+from torch.distributed.tensor import DTensor


 inductor_decomps = select_decomp_table()
@@ -7,11 +7,11 @@ from collections import defaultdict
 from typing import Any, Dict

 import torch
+import torch.distributed._tools.mod_tracker as mod_tracker
 import torch.nn
 from torch._guards import detect_fake_mode
 from torch.autograd.graph import register_multi_grad_hook
-from torch.distributed._tensor.api import DTensor
-from torch.distributed._tools.mod_tracker import ModTracker
+from torch.distributed.tensor.api import DTensor
 from torch.nn.modules.module import (
     register_module_forward_hook,
     register_module_forward_pre_hook,
@@ -69,7 +69,7 @@ trivial_ops = {
 }


-class CommModeModuleTracker(ModTracker):
+class _CommModeModuleTracker(mod_tracker.ModTracker):
     """
     Inherits ModuleTracker and expands on its functionality to track the
     parameters and sharding information of a model at a module-level
@@ -250,7 +250,7 @@ class CommDebugMode(TorchDispatchMode):
             self.comm_registry.add(py_op)

         self.comm_registry.add(torch.ops._dtensor.shard_dim_alltoall)
-        self.advanced_module_tracker = CommModeModuleTracker()
+        self.advanced_module_tracker = _CommModeModuleTracker()

     def generate_json_dump(self, file_name="comm_mode_log.json", noise_level=3):
         """
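CommDebugMode, whose module tracker is renamed above, is a TorchDispatchMode: it wraps the region being profiled as a context manager, and generate_json_dump (signature shown above) writes the collected log. A hedged usage sketch; get_total_counts is an assumed accessor and the mesh/tensors are placeholders:

    import torch
    from torch.distributed.device_mesh import init_device_mesh
    from torch.distributed.tensor import Shard, distribute_tensor
    from torch.distributed.tensor.debug import CommDebugMode

    mesh = init_device_mesh("cuda", (2,))
    a = distribute_tensor(torch.randn(8, 8), mesh, [Shard(0)])
    b = distribute_tensor(torch.randn(8, 8), mesh, [Shard(1)])

    comm_mode = CommDebugMode()
    with comm_mode:
        (a @ b).sum()                        # DTensor matmul triggers collectives
    print(comm_mode.get_total_counts())      # assumed accessor
    comm_mode.generate_json_dump(file_name="comm_mode_log.json", noise_level=3)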
@@ -4,8 +4,8 @@ from typing import List, Sequence, Tuple
 import numpy as np

 from torch._prims_common import ShapeType
-from torch.distributed._tensor import DeviceMesh
-from torch.distributed._tensor.placement_types import Placement, Shard
+from torch.distributed.tensor import DeviceMesh
+from torch.distributed.tensor.placement_types import Placement, Shard


 __all__ = ["visualize_sharding"]
@@ -8,8 +8,8 @@ from typing import Callable, Dict, Union

 import torch
 import torch.nn as nn
-from torch.distributed._tensor import DeviceMesh
-from torch.distributed._tensor.debug import CommDebugMode
+from torch.distributed.tensor import DeviceMesh
+from torch.distributed.tensor.debug import CommDebugMode
 from torch.distributed.tensor.parallel import (
     ColwiseParallel,
     parallelize_module,
@@ -12,7 +12,7 @@ import time
 import torch
 import torch.distributed as dist
 import torch.nn as nn
-from torch.distributed._tensor import (
+from torch.distributed.tensor import (
     DeviceMesh,
     distribute_module,
     distribute_tensor,
@@ -9,23 +9,23 @@ from functools import cached_property
 from typing import List, TYPE_CHECKING

 import torch
-from torch.distributed._tensor import (
+from torch.distributed.checkpoint.metadata import (
+    ChunkStorageMetadata,
+    TensorProperties,
+    TensorStorageMetadata,
+)
+from torch.distributed.tensor import (
     DeviceMesh,
     DTensor,
     init_device_mesh,
     Replicate,
     Shard,
 )
-from torch.distributed._tensor.debug.visualize_sharding import visualize_sharding
-from torch.distributed.checkpoint.metadata import (
-    ChunkStorageMetadata,
-    TensorProperties,
-    TensorStorageMetadata,
-)
+from torch.distributed.tensor.debug.visualize_sharding import visualize_sharding


 if TYPE_CHECKING:
-    from torch.distributed._tensor.placement_types import Placement
+    from torch.distributed.tensor.placement_types import Placement


 def get_device_type():
@@ -6,8 +6,8 @@ torchrun --standalone --nnodes=1 --nproc-per-node=4 visualize_sharding_example.p
 import os

 import torch
-from torch.distributed._tensor import DeviceMesh, distribute_tensor, Replicate, Shard
-from torch.distributed._tensor.debug.visualize_sharding import visualize_sharding
+from torch.distributed.tensor import DeviceMesh, distribute_tensor, Replicate, Shard
+from torch.distributed.tensor.debug.visualize_sharding import visualize_sharding


 world_size = int(os.environ["WORLD_SIZE"])
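The example above is launched with the torchrun command in its header. A condensed, hedged sketch of the same flow against the new public path (the header keyword argument is assumed from the example's usage):

    # torchrun --standalone --nnodes=1 --nproc-per-node=4 demo.py
    import os
    import torch
    from torch.distributed.tensor import DeviceMesh, Shard, distribute_tensor
    from torch.distributed.tensor.debug.visualize_sharding import visualize_sharding

    world_size = int(os.environ["WORLD_SIZE"])
    mesh = DeviceMesh("cuda", list(range(world_size)))
    dt = distribute_tensor(torch.randn(8, 8), mesh, [Shard(0)])
    visualize_sharding(dt, header="row-wise sharded")  # header kwarg assumed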
@@ -2,9 +2,9 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates
 from contextlib import contextmanager

-from torch.distributed._tensor.api import DTensor
-from torch.distributed._tensor.experimental.func_map import local_map
-from torch.distributed._tensor.experimental.register_sharding import register_sharding
+from torch.distributed.tensor.api import DTensor
+from torch.distributed.tensor.experimental.func_map import local_map
+from torch.distributed.tensor.experimental.register_sharding import register_sharding


 __all__ = ["implicit_replication", "local_map", "register_sharding"]
@@ -24,8 +24,8 @@ import torch.distributed as dist
 import torch.distributed._functional_collectives as ft_c
 import torch.nn.functional as F
 from torch import nn
-from torch.distributed._tensor import distribute_module, DTensor, Replicate, Shard
 from torch.distributed.device_mesh import DeviceMesh
+from torch.distributed.tensor import distribute_module, DTensor, Replicate, Shard
 from torch.distributed.tensor.parallel.style import ParallelStyle

@@ -4,8 +4,8 @@ from typing import Callable, Optional, Sequence, Tuple, Union

 import torch
 from torch.distributed._functional_collectives import AsyncCollectiveTensor
-from torch.distributed._tensor import DeviceMesh, DTensor
-from torch.distributed._tensor.placement_types import Placement
+from torch.distributed.tensor import DeviceMesh, DTensor
+from torch.distributed.tensor.placement_types import Placement


 try:
@@ -16,7 +16,6 @@ except ImportError:

 __all__ = ["local_map"]

-
 PlacementType = Optional[Sequence[Placement]]
 InputPlacements = Optional[Tuple[PlacementType, ...]]
 OutputPlacements = Union[PlacementType, Tuple[PlacementType, ...]]
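The aliases above describe what local_map accepts: one placement sequence per input plus the placements to stamp on the output. A hedged sketch of wrapping a plain-tensor function so it runs on the local shards of DTensor arguments (keyword names per the experimental API; treat the details as assumptions):

    import torch
    from torch.distributed.device_mesh import init_device_mesh
    from torch.distributed.tensor import Replicate, Shard, distribute_tensor
    from torch.distributed.tensor.experimental import local_map

    def local_mm(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return torch.mm(x, y)                       # sees plain local tensors

    mesh = init_device_mesh("cuda", (2,))
    dist_mm = local_map(
        local_mm,
        out_placements=[Shard(0)],                  # layout of the wrapped output
        in_placements=([Shard(0)], [Replicate()]),  # expected layout per input
    )
    a = distribute_tensor(torch.randn(8, 4), mesh, [Shard(0)])
    b = distribute_tensor(torch.randn(4, 4), mesh, [Replicate()])
    c = dist_mm(a, b)                               # DTensor sharded on dim 0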
@@ -5,8 +5,8 @@ from typing import Callable, List, Sequence, Tuple, Union

 import torch
 from torch._ops import OpOverload
-from torch.distributed._tensor import DeviceMesh, DTensor
-from torch.distributed._tensor._op_schema import (
+from torch.distributed.tensor import DeviceMesh, DTensor
+from torch.distributed.tensor._op_schema import (
     _is_inplace_op,
     OpSchema,
     OpStrategy,
@@ -15,7 +15,7 @@ from torch.distributed._tensor._op_schema import (
     StrategyType,
     TupleStrategy,
 )
-from torch.distributed._tensor.ops.utils import expand_to_full_mesh_op_strategy
+from torch.distributed.tensor._ops.utils import expand_to_full_mesh_op_strategy


 __all__ = ["register_sharding"]
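register_sharding, re-exported above, is the experimental decorator for declaring which shardings an operator accepts; the registered function returns a list of (output placements, input placements) pairs. A rough sketch along those lines (the softmax example and argument handling are assumptions based on the experimental API, not this patch):

    import torch
    from torch.distributed.tensor import Replicate, Shard
    from torch.distributed.tensor.experimental import register_sharding

    aten = torch.ops.aten

    @register_sharding(aten._softmax.default)
    def custom_softmax_sharding(x, dim, half_to_float):
        softmax_dim = dim if dim >= 0 else dim + x.ndim
        acceptable = [([Replicate()], [Replicate(), None, None])]
        for d in range(x.ndim):
            if d != softmax_dim:
                # sharding any non-softmax dim keeps the op fully local
                acceptable.append(([Shard(d)], [Shard(d), None, None]))
        return acceptable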
@@ -5,22 +5,22 @@ from typing import Any, cast, Dict, List, Optional, Sequence, Tuple

 import torch
 from torch._subclasses.fake_tensor import FakeTensor
-from torch.distributed._tensor import DeviceMesh, distribute_tensor, DTensor
-from torch.distributed._tensor._op_schema import (
+from torch.distributed.tensor import DeviceMesh, distribute_tensor, DTensor
+from torch.distributed.tensor._op_schema import (
     DTensorSpec,
     OpSchema,
     OutputSharding,
     OutputSpecType,
     PlacementStrategy,
 )
-from torch.distributed._tensor._redistribute import redistribute_local_tensor
-from torch.distributed._tensor.placement_types import (
+from torch.distributed.tensor._redistribute import redistribute_local_tensor
+from torch.distributed.tensor.parallel.style import ColwiseParallel, ParallelStyle
+from torch.distributed.tensor.placement_types import (
     Placement,
     Replicate,
     Shard,
     TensorMeta,
 )
-from torch.distributed.tensor.parallel.style import ColwiseParallel, ParallelStyle
 from torch.export import ExportedProgram
 from torch.export.exported_program import ExportGraphSignature
 from torch.fx import GraphModule
@@ -3,8 +3,8 @@ from typing import no_type_check, Optional, Tuple

 import torch
 from torch.distributed._functional_collectives import AsyncCollectiveTensor
-from torch.distributed._tensor import DTensor
-from torch.distributed._tensor.placement_types import DTensorSpec
+from torch.distributed.tensor import DTensor
+from torch.distributed.tensor.placement_types import DTensorSpec


 @no_type_check
@@ -2,9 +2,9 @@
 import warnings
 from typing import Tuple, Union

-from torch.distributed._tensor import DeviceMesh
-from torch.distributed._tensor.placement_types import Placement
 from torch.distributed.device_mesh import _mesh_resources
+from torch.distributed.tensor import DeviceMesh
+from torch.distributed.tensor.placement_types import Placement


 try:
@@ -3,15 +3,15 @@ from fnmatch import fnmatch
 from typing import Dict, Union

 import torch
-import torch.distributed._tensor.random as random
+import torch.distributed.tensor.random as random
 import torch.nn as nn
-from torch.distributed._tensor import DeviceMesh
-from torch.distributed._tensor.random import (
+from torch.distributed.tensor import DeviceMesh
+from torch.distributed.tensor.parallel._utils import _validate_tp_mesh_dim
+from torch.distributed.tensor.parallel.style import ParallelStyle
+from torch.distributed.tensor.random import (
     is_rng_supported_mesh,
     TensorParallelRNGTracker,
 )
-from torch.distributed.tensor.parallel._utils import _validate_tp_mesh_dim
-from torch.distributed.tensor.parallel.style import ParallelStyle


 __all__ = [
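The imports touched above belong to the tensor-parallel API, whose parallelize_module takes a plan mapping submodule names to ParallelStyle objects. A hedged end-to-end sketch (the toy model, submodule names, and mesh size are placeholders):

    import torch
    import torch.nn as nn
    from torch.distributed.device_mesh import init_device_mesh
    from torch.distributed.tensor.parallel import (
        ColwiseParallel,
        RowwiseParallel,
        parallelize_module,
    )

    class Toy(nn.Module):
        def __init__(self):
            super().__init__()
            self.up = nn.Linear(16, 32)
            self.down = nn.Linear(32, 16)

        def forward(self, x):
            return self.down(torch.relu(self.up(x)))

    tp_mesh = init_device_mesh("cuda", (2,))
    model = parallelize_module(
        Toy().cuda(),
        tp_mesh,
        {"up": ColwiseParallel(), "down": RowwiseParallel()},
    )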
@@ -14,12 +14,12 @@ from torch.distributed._shard.sharded_tensor import (
 )
 from torch.distributed._shard.sharding_spec import ShardMetadata
 from torch.distributed._shard.sharding_spec.chunk_sharding_spec import ChunkShardingSpec
-from torch.distributed._tensor import DeviceMesh, DTensor, Replicate, Shard as DShard
 from torch.distributed.device_mesh import _mesh_resources
 from torch.distributed.fsdp._common_utils import _set_fsdp_flattened
 from torch.distributed.fsdp._fsdp_extensions import FSDPExtensions
 from torch.distributed.fsdp._shard_utils import _create_chunk_sharded_tensor
 from torch.distributed.remote_device import _remote_device
+from torch.distributed.tensor import DeviceMesh, DTensor, Replicate, Shard as DShard
 from torch.distributed.tensor.parallel._data_parallel_utils import (
     _flatten_tensor,
     _unflatten_tensor,
@@ -3,7 +3,7 @@ from functools import partial
 from typing import Any, Optional, Tuple

 import torch
-from torch.distributed._tensor import DeviceMesh, DTensor, Replicate, Shard
+from torch.distributed.tensor import DeviceMesh, DTensor, Replicate, Shard


 __all__ = [
@@ -8,15 +8,15 @@ import torch._prims_common as utils
 import torch.distributed._functional_collectives as funcol
 import torch.distributed.distributed_c10d as c10d
 from torch import Tensor
-from torch.distributed._tensor import DTensor, Replicate, Shard
-from torch.distributed._tensor.ops._embedding_ops import _MaskPartial
-from torch.distributed._tensor.ops._math_ops import (
+from torch.distributed.device_mesh import DeviceMesh
+from torch.distributed.tensor import DTensor, Replicate, Shard
+from torch.distributed.tensor._ops._embedding_ops import _MaskPartial
+from torch.distributed.tensor._ops._math_ops import (
     _skip_dim,
     Reduction,
     replicate_reduction_dims,
 )
-from torch.distributed._tensor.placement_types import DTensorSpec, Placement, TensorMeta
-from torch.distributed.device_mesh import DeviceMesh
+from torch.distributed.tensor.placement_types import DTensorSpec, Placement, TensorMeta


 aten = torch.ops.aten
@@ -6,7 +6,7 @@ from typing import Any, Dict, Optional, Tuple, Union

 import torch
 import torch.nn as nn
-from torch.distributed._tensor import (
+from torch.distributed.tensor import (
     DeviceMesh,
     distribute_module,
     distribute_tensor,
@@ -14,7 +14,7 @@ from torch.distributed._tensor import (
     Replicate,
     Shard,
 )
-from torch.distributed._tensor.placement_types import Placement
+from torch.distributed.tensor.placement_types import Placement


 __all__ = [
@@ -6,7 +6,8 @@ from typing import Any, cast, List, NamedTuple, Optional, Tuple

 import torch
 import torch.distributed._functional_collectives as funcol
-from torch.distributed._tensor._collective_utils import (
+from torch.distributed.device_mesh import DeviceMesh
+from torch.distributed.tensor._collective_utils import (
     fill_empty_tensor_to_shards,
     mesh_broadcast,
     mesh_scatter,
@@ -14,7 +15,6 @@ from torch.distributed._tensor._collective_utils import (
     shard_dim_alltoall,
     unpad_tensor,
 )
-from torch.distributed.device_mesh import DeviceMesh


 __all__ = ["Placement", "Shard", "Replicate", "Partial", "DTensorSpec", "TensorMeta"]
@@ -7,8 +7,8 @@ from typing import Dict, List, Optional
 import torch
 import torch.distributed as dist
 from torch import Tensor
-from torch.distributed._tensor.placement_types import DTensorSpec, Shard
 from torch.distributed.device_mesh import _get_device_handle, DeviceMesh
+from torch.distributed.tensor.placement_types import DTensorSpec, Shard


 __all__ = [
@@ -290,7 +290,7 @@ class OffsetBasedRNGTracker(_RNGStateTracker):
             return_offset=False,
         )[0]

-        from torch.distributed._tensor.ops.utils import prod
+        from torch.distributed.tensor._ops.utils import prod

         local_size = prod(local_size_on_rank_0)

@@ -317,7 +317,7 @@ class OffsetBasedRNGTracker(_RNGStateTracker):
        """
        dtensor_shape = spec.shape

-        from torch.distributed._tensor.ops.utils import prod
+        from torch.distributed.tensor._ops.utils import prod

        numel = prod(dtensor_shape)
        # pytorch: offset must be multiple of 4
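The retained comment above ("offset must be multiple of 4") is why the tracker rounds the RNG offset increment it derives from numel. A one-line sketch of that rounding, as an illustration rather than a quote of the surrounding method:

    numel = 6
    offset_incr = (numel + 3) // 4 * 4   # rounds 6 up to 8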
@@ -34,7 +34,6 @@ from torch.distributed._composable.fsdp._fsdp_param_group import (
     FSDPParamGroup,
     RegisterPostBackwardFunction,
 )
-from torch.distributed._tensor import distribute_tensor, DTensor, Shard
 from torch.distributed.device_mesh import DeviceMesh
 from torch.distributed.fsdp import CPUOffload, FullyShardedDataParallel as FSDP
 from torch.distributed.fsdp._common_utils import TrainingState
@@ -46,6 +45,7 @@ from torch.distributed.fsdp.fully_sharded_data_parallel import (
 )
 from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
 from torch.distributed.fsdp.wrap import always_wrap_policy, ModuleWrapPolicy, wrap
+from torch.distributed.tensor import distribute_tensor, DTensor, Shard
 from torch.distributed.tensor.parallel import (
     ColwiseParallel,
     parallelize_module,