Revert "[dtensor] move tensor constructors to a separate module (#133129)"

This reverts commit e890d888d9.

Reverted https://github.com/pytorch/pytorch/pull/133129 on behalf of https://github.com/fbgheith due to breaking internal tests ([comment](https://github.com/pytorch/pytorch/pull/133129#issuecomment-2285090400))
PyTorch MergeBot 2024-08-12 23:55:08 +00:00
parent 89670d5bdd
commit 00aa086298
4 changed files with 340 additions and 354 deletions
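For context: this revert moves the DTensor constructors (`ones`, `empty`, `full`, `rand`, `randn`, `zeros`) back into `torch/distributed/_tensor/__init__.py` and deletes `_tensor_constructors.py`. The user-facing import path is the same either way; below is a minimal, hedged sketch of the API (not part of this commit), assuming a working distributed launch (e.g. `torchrun` with 2 ranks) and an illustrative 1-D CPU mesh:

```python
# Illustrative only -- not part of this commit. Assumes torch.distributed is
# initialized (e.g. `torchrun --nproc_per_node=2 script.py`).
from torch.distributed._tensor import init_device_mesh, ones, Shard

mesh = init_device_mesh("cpu", (2,))  # hypothetical 1-D mesh over 2 ranks
# Global shape (4, 8) sharded on dim 0: each rank holds a (2, 8) local shard.
t = ones(4, 8, device_mesh=mesh, placements=[Shard(0)])
print(t.to_local().shape)
```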


@@ -1,19 +1,21 @@
 # mypy: allow-untyped-defs
 # Copyright (c) Meta Platforms, Inc. and affiliates
+from typing import Optional, Sequence
+
+# Import all builtin dist tensor ops
 import torch
-import torch.distributed._tensor.ops as _ops  # force import all built-in dtensor ops
-from torch.distributed._tensor._tensor_constructors import (
-    empty,
-    full,
-    ones,
-    rand,
-    randn,
-    zeros,
-)
+import torch.distributed._tensor.ops
+import torch.distributed._tensor.random as random
+from torch.distributed._tensor._utils import compute_local_shape
 from torch.distributed._tensor.api import distribute_module, distribute_tensor, DTensor
-from torch.distributed._tensor.placement_types import Partial, Replicate, Shard
-from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
+from torch.distributed._tensor.ops.utils import normalize_to_torch_size
+from torch.distributed._tensor.placement_types import (
+    Partial,
+    Placement,
+    Replicate,
+    Shard,
+)
+from torch.distributed.device_mesh import _mesh_resources, DeviceMesh, init_device_mesh
 from torch.optim.optimizer import (
     _foreach_supported_types as _optim_foreach_supported_types,
 )
@@ -32,12 +34,6 @@ __all__ = [
     "Shard",
     "Replicate",
     "Partial",
-    "ones",
-    "empty",
-    "full",
-    "rand",
-    "randn",
-    "zeros",
 ]
@@ -48,3 +44,328 @@ if DTensor not in _optim_foreach_supported_types:
if DTensor not in _util_foreach_supported_types:
    _util_foreach_supported_types.append(DTensor)


def _dtensor_init_helper(
init_op,
size: torch.Size,
device_mesh=None,
placements=None,
**kwargs,
) -> DTensor:
from torch.distributed._tensor.placement_types import DTensorSpec, TensorMeta
# if device_mesh is None, use the one from mesh resources
device_mesh = device_mesh or _mesh_resources.get_current_mesh()
kwargs["device"] = device_mesh.device_type
# set default placements to replicated if not specified
placements = placements or tuple(Replicate() for _ in range(device_mesh.ndim))
# check device_mesh against placements
assert device_mesh.ndim == len(
placements
), "mesh dimension does not match the length of placements"
assert kwargs["layout"] == torch.strided, "layout value not supported!"
torch_stride = torch._prims_common.make_contiguous_strides_for(size)
# get local tensor shape
local_shape = compute_local_shape(size, device_mesh, placements)
# initialize the local tensor
if init_op == torch.full:
fill_value = kwargs.pop("fill_value", 0)
local_tensor = init_op(local_shape, fill_value, **kwargs)
elif init_op == torch.rand or init_op == torch.randn:
# this tensor meta is not used except `shape`
dtype = kwargs.get("dtype", torch.get_default_dtype())
tensor_meta = TensorMeta(size, (0,), dtype)
spec = DTensorSpec(device_mesh, placements, tensor_meta=tensor_meta)
if random.is_rng_supported_mesh(device_mesh) and not random._rng_tracker:
random._rng_tracker = random.OffsetBasedRNGTracker()
assert random._rng_tracker is not None
with random._rng_tracker._distribute_region(spec):
local_tensor = init_op(local_shape, **kwargs)
else:
local_tensor = init_op(local_shape, **kwargs)
spec = DTensorSpec(
device_mesh,
tuple(placements),
tensor_meta=TensorMeta(
size,
torch_stride,
local_tensor.dtype,
),
)
return DTensor(
local_tensor,
spec,
requires_grad=kwargs["requires_grad"],
)


def ones(
*size,
dtype: Optional[torch.dtype] = None,
layout: torch.layout = torch.strided,
requires_grad: bool = False,
device_mesh: Optional[DeviceMesh] = None,
placements: Optional[Sequence[Placement]] = None,
) -> DTensor:
"""
Returns a :class:`DTensor` filled with the scalar value 1, with the shape defined
by the variable argument ``size``.
Args:
size (int...): a sequence of integers defining the shape of the output :class:`DTensor`.
Can be a variable number of arguments or a collection like a list or tuple.
E.g.: ones(1,2,3..) or ones([1,2,3..]) or ones((1,2,3..))
Keyword args:
dtype (:class:`torch.dtype`, optional): the desired data type of returned :class:`DTensor`.
Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`).
layout (:class:`torch.layout`, optional): the desired layout of returned DTensor.
Default: ``torch.strided``.
requires_grad (bool, optional): If autograd should record operations on the
returned :class:`DTensor`. Default: ``False``.
device_mesh: :class:`DeviceMesh` type, contains the mesh info of ranks
placements: a sequence of :class:`Placement` type: ``Shard``, ``Replicate``
Returns:
A :class:`DTensor` object on each rank
"""
torch_size = normalize_to_torch_size(size)
return _dtensor_init_helper(
torch.ones,
torch_size,
dtype=dtype,
layout=layout,
requires_grad=requires_grad,
device_mesh=device_mesh,
placements=placements,
)


def empty(
*size,
dtype: Optional[torch.dtype] = None,
layout: torch.layout = torch.strided,
requires_grad: bool = False,
device_mesh: Optional[DeviceMesh] = None,
placements: Optional[Sequence[Placement]] = None,
) -> DTensor:
"""
Returns a :class:`DTensor` filled with uninitialized data. The shape of the :class:`DTensor`
is defined by the variable argument ``size``.
Args:
size (int...): a sequence of integers defining the shape of the output :class:`DTensor`.
Can be a variable number of arguments or a collection like a list or tuple.
E.g.: empty(1,2,3..) or empty([1,2,3..]) or empty((1,2,3..))
Keyword args:
dtype (:class:`torch.dtype`, optional): the desired data type of returned :class:`DTensor`.
Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`).
layout (:class:`torch.layout`, optional): the desired layout of returned :class:`DTensor`.
Default: ``torch.strided``.
requires_grad (bool, optional): If autograd should record operations on the
returned :class:`DTensor`. Default: ``False``.
device_mesh: :class:`DeviceMesh` type, contains the mesh info of ranks
placements: a sequence of :class:`Placement` type: ``Shard``, ``Replicate``
Returns:
A :class:`DTensor` object on each rank
"""
torch_size = normalize_to_torch_size(size)
return _dtensor_init_helper(
torch.empty,
torch_size,
dtype=dtype,
layout=layout,
requires_grad=requires_grad,
device_mesh=device_mesh,
placements=placements,
)


def full(
size,
fill_value,
*,
dtype: Optional[torch.dtype] = None,
layout: torch.layout = torch.strided,
requires_grad: bool = False,
device_mesh: Optional[DeviceMesh] = None,
placements: Optional[Sequence[Placement]] = None,
) -> DTensor:
"""
Returns a :class:`DTensor` filled with ``fill_value``. The scalar value type should match
``device_mesh.device_type``.
Args:
size (int...): a sequence of integers defining the shape of the output :class:`DTensor`.
Can be a variable number of arguments or a collection like a list or tuple.
E.g.: ones(1,2,3..) or ones([1,2,3..]) or ones((1,2,3..))
fill_value(Scalar): the value to fill the output tensor with.
Keyword args:
dtype (:class:`torch.dtype`, optional): the desired data type of returned :class:`DTensor`.
Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`).
layout (:class:`torch.layout`, optional): the desired layout of returned DTensor.
Default: ``torch.strided``.
requires_grad (bool, optional): If autograd should record operations on the
returned :class:`DTensor`. Default: ``False``.
device_mesh: :class:`DeviceMesh` type, contains the mesh info of ranks.
placements: a sequence of :class:`Placement` type: ``Shard``, ``Replicate``
Returns:
A :class:`DTensor` object on each rank
"""
torch_size = normalize_to_torch_size(size)
return _dtensor_init_helper(
torch.full,
torch_size,
fill_value=fill_value,
dtype=dtype,
layout=layout,
requires_grad=requires_grad,
device_mesh=device_mesh,
placements=placements,
)


def rand(
*size,
requires_grad: bool = False,
dtype: Optional[torch.dtype] = None,
layout: torch.layout = torch.strided,
device_mesh: Optional[DeviceMesh] = None,
placements: Optional[Sequence[Placement]] = None,
) -> DTensor:
"""
Returns a :class:`DTensor` filled with random numbers from a uniform distribution
on the interval ``[0, 1)``. The shape of the tensor is defined by the variable
argument ``size``.
Args:
size (int...): a sequence of integers defining the shape of the output :class:`DTensor`.
Can be a variable number of arguments or a collection like a list or tuple.
E.g.: rand(1,2,3..) or rand([1,2,3..]) or rand((1,2,3..))
Keyword args:
dtype (:class:`torch.dtype`, optional): the desired data type of returned :class:`DTensor`.
Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`).
layout (:class:`torch.layout`, optional): the desired layout of returned DTensor.
Default: ``torch.strided``.
requires_grad (bool, optional): If autograd should record operations on the
returned :class:`DTensor`. Default: ``False``.
device_mesh: :class:`DeviceMesh` type, contains the mesh info of ranks.
placements: a sequence of :class:`Placement` type: ``Shard``, ``Replicate``
Returns:
A :class:`DTensor` object on each rank
"""
torch_size = normalize_to_torch_size(size)
return _dtensor_init_helper(
torch.rand,
torch_size,
dtype=dtype,
layout=layout,
requires_grad=requires_grad,
device_mesh=device_mesh,
placements=placements,
)


def randn(
*size,
requires_grad: bool = False,
dtype: Optional[torch.dtype] = None,
layout: torch.layout = torch.strided,
device_mesh: Optional[DeviceMesh] = None,
placements: Optional[Sequence[Placement]] = None,
) -> DTensor:
"""
Returns a :class:`DTensor` filled with random numbers from a normal distribution
with mean 0 and variance 1. The shape of the tensor is defined by the variable
argument ``size``.
Args:
size (int...): a sequence of integers defining the shape of the output :class:`DTensor`.
Can be a variable number of arguments or a collection like a list or tuple.
E.g.: randn(1,2,3..) or randn([1,2,3..]) or randn((1,2,3..))
Keyword args:
dtype (:class:`torch.dtype`, optional): the desired data type of returned :class:`DTensor`.
Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`).
layout (:class:`torch.layout`, optional): the desired layout of returned DTensor.
Default: ``torch.strided``.
requires_grad (bool, optional): If autograd should record operations on the
returned :class:`DTensor`. Default: ``False``.
device_mesh: :class:`DeviceMesh` type, contains the mesh info of ranks.
placements: a sequence of :class:`Placement` type: ``Shard``, ``Replicate``
Returns:
A :class:`DTensor` object on each rank
"""
torch_size = normalize_to_torch_size(size)
return _dtensor_init_helper(
torch.randn,
torch_size,
dtype=dtype,
layout=layout,
requires_grad=requires_grad,
device_mesh=device_mesh,
placements=placements,
)


def zeros(
*size,
requires_grad: bool = False,
dtype: Optional[torch.dtype] = None,
layout: torch.layout = torch.strided,
device_mesh: Optional[DeviceMesh] = None,
placements: Optional[Sequence[Placement]] = None,
) -> DTensor:
"""
Returns a :class:`DTensor` filled with the scalar value 0.
Args:
size (int...): a sequence of integers defining the shape of the output :class:`DTensor`.
Can be a variable number of arguments or a collection like a list or tuple.
E.g.: zeros(1,2,3..) or zeros([1,2,3..]) or zeros((1,2,3..))
Keyword args:
requires_grad (bool, optional): If autograd should record operations on the
returned :class:`DTensor`. Default: ``False``.
dtype (:class:`torch.dtype`, optional): the desired data type of returned :class:`DTensor`.
Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`).
layout (:class:`torch.layout`, optional): the desired layout of returned :class:`DTensor`.
Default: ``torch.strided``.
device_mesh: :class:`DeviceMesh` type, contains the mesh info of ranks
placements: a sequence of :class:`Placement` type: ``Shard``, ``Replicate``
Returns:
A :class:`DTensor` object on each rank
"""
torch_size = normalize_to_torch_size(size)
return _dtensor_init_helper(
torch.zeros,
torch_size,
dtype=dtype,
layout=layout,
requires_grad=requires_grad,
device_mesh=device_mesh,
placements=placements,
)
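
A reading aid for the block above (not part of the diff): every constructor funnels into `_dtensor_init_helper`, which defaults `placements` to `Replicate()` on every mesh dimension and uses `compute_local_shape` to size each rank's local tensor; `rand`/`randn` additionally run inside the RNG tracker's distribute region on supported meshes. A hedged sketch of that behavior on a hypothetical 2-rank mesh:

```python
# Hedged illustration of the helper's semantics; assumes a 2-rank run (e.g. via torchrun).
from torch.distributed._tensor import init_device_mesh, rand, zeros, Shard

mesh = init_device_mesh("cpu", (2,))

# No placements given: every mesh dim defaults to Replicate(), so the local
# tensor has the full global shape on each rank.
replicated = zeros(4, 6, device_mesh=mesh)
assert replicated.to_local().shape == (4, 6)

# Shard(0): compute_local_shape splits dim 0 across the 2 ranks -> (2, 6) locally.
sharded = rand(4, 6, device_mesh=mesh, placements=[Shard(0)])
assert sharded.to_local().shape == (2, 6)
```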


@@ -1,334 +0,0 @@
from typing import Optional, Sequence
import torch
import torch.distributed._tensor.random as random
from torch.distributed._tensor._utils import compute_local_shape
from torch.distributed._tensor.api import DTensor
from torch.distributed._tensor.ops.utils import normalize_to_torch_size
from torch.distributed._tensor.placement_types import Placement, Replicate
from torch.distributed.device_mesh import _mesh_resources, DeviceMesh


def _dtensor_init_helper(  # type: ignore[no-untyped-def]
init_op,
size: torch.Size,
device_mesh: Optional[DeviceMesh] = None,
placements: Optional[Sequence[Placement]] = None,
**kwargs,
) -> DTensor:
from torch.distributed._tensor.placement_types import DTensorSpec, TensorMeta
# if device_mesh is None, use the one from mesh resources
device_mesh = device_mesh or _mesh_resources.get_current_mesh()
kwargs["device"] = device_mesh.device_type
# set default placements to replicated if not specified
placements = placements or tuple(Replicate() for _ in range(device_mesh.ndim))
# check device_mesh against placements
assert device_mesh.ndim == len(
placements
), "mesh dimension does not match the length of placements"
assert kwargs["layout"] == torch.strided, "layout value not supported!"
torch_stride = torch._prims_common.make_contiguous_strides_for(size)
# get local tensor shape
local_shape = compute_local_shape(size, device_mesh, placements)
# initialize the local tensor
if init_op == torch.full:
fill_value = kwargs.pop("fill_value", 0)
local_tensor = init_op(local_shape, fill_value, **kwargs)
elif init_op == torch.rand or init_op == torch.randn:
# this tensor meta is not used except `shape`
dtype = kwargs.get("dtype", torch.get_default_dtype())
tensor_meta = TensorMeta(size, (0,), dtype)
spec = DTensorSpec(device_mesh, tuple(placements), tensor_meta=tensor_meta)
if random.is_rng_supported_mesh(device_mesh) and not random._rng_tracker:
random._rng_tracker = random.OffsetBasedRNGTracker()
assert random._rng_tracker is not None
with random._rng_tracker._distribute_region(spec):
local_tensor = init_op(local_shape, **kwargs)
else:
local_tensor = init_op(local_shape, **kwargs)
spec = DTensorSpec(
device_mesh,
tuple(placements),
tensor_meta=TensorMeta(
size,
torch_stride,
local_tensor.dtype,
),
)
return DTensor(
local_tensor,
spec,
requires_grad=kwargs["requires_grad"],
)


def ones(  # type: ignore[no-untyped-def]
*size,
dtype: Optional[torch.dtype] = None,
layout: torch.layout = torch.strided,
requires_grad: bool = False,
device_mesh: Optional[DeviceMesh] = None,
placements: Optional[Sequence[Placement]] = None,
) -> DTensor:
"""
Returns a :class:`DTensor` filled with the scalar value 1, with the shape defined
by the variable argument ``size``.
Args:
size (int...): a sequence of integers defining the shape of the output :class:`DTensor`.
Can be a variable number of arguments or a collection like a list or tuple.
E.g.: ones(1,2,3..) or ones([1,2,3..]) or ones((1,2,3..))
Keyword args:
dtype (:class:`torch.dtype`, optional): the desired data type of returned :class:`DTensor`.
Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`).
layout (:class:`torch.layout`, optional): the desired layout of returned DTensor.
Default: ``torch.strided``.
requires_grad (bool, optional): If autograd should record operations on the
returned :class:`DTensor`. Default: ``False``.
device_mesh: :class:`DeviceMesh` type, contains the mesh info of ranks
placements: a sequence of :class:`Placement` type: ``Shard``, ``Replicate``
Returns:
A :class:`DTensor` object on each rank
"""
torch_size = normalize_to_torch_size(size)
return _dtensor_init_helper(
torch.ones,
torch_size,
dtype=dtype,
layout=layout,
requires_grad=requires_grad,
device_mesh=device_mesh,
placements=placements,
)


def empty(  # type: ignore[no-untyped-def]
*size,
dtype: Optional[torch.dtype] = None,
layout: torch.layout = torch.strided,
requires_grad: bool = False,
device_mesh: Optional[DeviceMesh] = None,
placements: Optional[Sequence[Placement]] = None,
) -> DTensor:
"""
Returns a :class:`DTensor` filled with uninitialized data. The shape of the :class:`DTensor`
is defined by the variable argument ``size``.
Args:
size (int...): a sequence of integers defining the shape of the output :class:`DTensor`.
Can be a variable number of arguments or a collection like a list or tuple.
E.g.: empty(1,2,3..) or empty([1,2,3..]) or empty((1,2,3..))
Keyword args:
dtype (:class:`torch.dtype`, optional): the desired data type of returned :class:`DTensor`.
Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`).
layout (:class:`torch.layout`, optional): the desired layout of returned :class:`DTensor`.
Default: ``torch.strided``.
requires_grad (bool, optional): If autograd should record operations on the
returned :class:`DTensor`. Default: ``False``.
device_mesh: :class:`DeviceMesh` type, contains the mesh info of ranks
placements: a sequence of :class:`Placement` type: ``Shard``, ``Replicate``
Returns:
A :class:`DTensor` object on each rank
"""
torch_size = normalize_to_torch_size(size)
return _dtensor_init_helper(
torch.empty,
torch_size,
dtype=dtype,
layout=layout,
requires_grad=requires_grad,
device_mesh=device_mesh,
placements=placements,
)


def full(  # type: ignore[no-untyped-def]
size,
fill_value,
*,
dtype: Optional[torch.dtype] = None,
layout: torch.layout = torch.strided,
requires_grad: bool = False,
device_mesh: Optional[DeviceMesh] = None,
placements: Optional[Sequence[Placement]] = None,
) -> DTensor:
"""
Returns a :class:`DTensor` filled with ``fill_value``. The scalar value type should match
``device_mesh.device_type``.
Args:
size (int...): a sequence of integers defining the shape of the output :class:`DTensor`.
Can be a variable number of arguments or a collection like a list or tuple.
E.g.: ones(1,2,3..) or ones([1,2,3..]) or ones((1,2,3..))
fill_value(Scalar): the value to fill the output tensor with.
Keyword args:
dtype (:class:`torch.dtype`, optional): the desired data type of returned :class:`DTensor`.
Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`).
layout (:class:`torch.layout`, optional): the desired layout of returned DTensor.
Default: ``torch.strided``.
requires_grad (bool, optional): If autograd should record operations on the
returned :class:`DTensor`. Default: ``False``.
device_mesh: :class:`DeviceMesh` type, contains the mesh info of ranks.
placements: a sequence of :class:`Placement` type: ``Shard``, ``Replicate``
Returns:
A :class:`DTensor` object on each rank
"""
torch_size = normalize_to_torch_size(size)
return _dtensor_init_helper(
torch.full,
torch_size,
fill_value=fill_value,
dtype=dtype,
layout=layout,
requires_grad=requires_grad,
device_mesh=device_mesh,
placements=placements,
)


def rand(  # type: ignore[no-untyped-def]
*size,
requires_grad: bool = False,
dtype: Optional[torch.dtype] = None,
layout: torch.layout = torch.strided,
device_mesh: Optional[DeviceMesh] = None,
placements: Optional[Sequence[Placement]] = None,
) -> DTensor:
"""
Returns a :class:`DTensor` filled with random numbers from a uniform distribution
on the interval ``[0, 1)``. The shape of the tensor is defined by the variable
argument ``size``.
Args:
size (int...): a sequence of integers defining the shape of the output :class:`DTensor`.
Can be a variable number of arguments or a collection like a list or tuple.
E.g.: rand(1,2,3..) or rand([1,2,3..]) or rand((1,2,3..))
Keyword args:
dtype (:class:`torch.dtype`, optional): the desired data type of returned :class:`DTensor`.
Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`).
layout (:class:`torch.layout`, optional): the desired layout of returned DTensor.
Default: ``torch.strided``.
requires_grad (bool, optional): If autograd should record operations on the
returned :class:`DTensor`. Default: ``False``.
device_mesh: :class:`DeviceMesh` type, contains the mesh info of ranks.
placements: a sequence of :class:`Placement` type: ``Shard``, ``Replicate``
Returns:
A :class:`DTensor` object on each rank
"""
torch_size = normalize_to_torch_size(size)
return _dtensor_init_helper(
torch.rand,
torch_size,
dtype=dtype,
layout=layout,
requires_grad=requires_grad,
device_mesh=device_mesh,
placements=placements,
)


def randn(  # type: ignore[no-untyped-def]
*size,
requires_grad: bool = False,
dtype: Optional[torch.dtype] = None,
layout: torch.layout = torch.strided,
device_mesh: Optional[DeviceMesh] = None,
placements: Optional[Sequence[Placement]] = None,
) -> DTensor:
"""
Returns a :class:`DTensor` filled with random numbers from a normal distribution
with mean 0 and variance 1. The shape of the tensor is defined by the variable
argument ``size``.
Args:
size (int...): a sequence of integers defining the shape of the output :class:`DTensor`.
Can be a variable number of arguments or a collection like a list or tuple.
E.g.: randn(1,2,3..) or randn([1,2,3..]) or randn((1,2,3..))
Keyword args:
dtype (:class:`torch.dtype`, optional): the desired data type of returned :class:`DTensor`.
Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`).
layout (:class:`torch.layout`, optional): the desired layout of returned DTensor.
Default: ``torch.strided``.
requires_grad (bool, optional): If autograd should record operations on the
returned :class:`DTensor`. Default: ``False``.
device_mesh: :class:`DeviceMesh` type, contains the mesh info of ranks.
placements: a sequence of :class:`Placement` type: ``Shard``, ``Replicate``
Returns:
A :class:`DTensor` object on each rank
"""
torch_size = normalize_to_torch_size(size)
return _dtensor_init_helper(
torch.randn,
torch_size,
dtype=dtype,
layout=layout,
requires_grad=requires_grad,
device_mesh=device_mesh,
placements=placements,
)


def zeros(  # type: ignore[no-untyped-def]
*size,
requires_grad: bool = False,
dtype: Optional[torch.dtype] = None,
layout: torch.layout = torch.strided,
device_mesh: Optional[DeviceMesh] = None,
placements: Optional[Sequence[Placement]] = None,
) -> DTensor:
"""
Returns a :class:`DTensor` filled with the scalar value 0.
Args:
size (int...): a sequence of integers defining the shape of the output :class:`DTensor`.
Can be a variable number of arguments or a collection like a list or tuple.
E.g.: zeros(1,2,3..) or zeros([1,2,3..]) or zeros((1,2,3..))
Keyword args:
requires_grad (bool, optional): If autograd should record operations on the
returned :class:`DTensor`. Default: ``False``.
dtype (:class:`torch.dtype`, optional): the desired data type of returned :class:`DTensor`.
Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`).
layout (:class:`torch.layout`, optional): the desired layout of returned :class:`DTensor`.
Default: ``torch.strided``.
device_mesh: :class:`DeviceMesh` type, contains the mesh info of ranks
placements: a sequence of :class:`Placement` type: ``Shard``, ``Replicate``
Returns:
A :class:`DTensor` object on each rank
"""
torch_size = normalize_to_torch_size(size)
return _dtensor_init_helper(
torch.zeros,
torch_size,
dtype=dtype,
layout=layout,
requires_grad=requires_grad,
device_mesh=device_mesh,
placements=placements,
)


@@ -1,4 +1,5 @@
 # mypy: allow-untyped-defs
+from torch.distributed._tensor.api import DTensor
 from torch.distributed._tensor.debug.comm_mode import CommDebugMode
 from torch.distributed._tensor.debug.visualize_sharding import visualize_sharding
@@ -12,8 +13,6 @@ def _get_sharding_prop_cache_info():
     This would return a named tuple showing hits, misses, maxsize and currsize of the sharding
     propagator cache.
     """
-    from torch.distributed._tensor.api import DTensor
     return (
         DTensor._op_dispatcher.sharding_propagator.propagate_op_sharding.cache_info()  # type:ignore[attr-defined]
     )
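
A small, hedged usage note (not part of the diff): regardless of where the `DTensor` import lives, the helper above just surfaces the sharding propagator's `lru_cache` statistics, so the result reads like any `functools` cache info object:

```python
# Illustrative only: inspect the sharding-propagation cache after running some DTensor ops.
from torch.distributed._tensor.debug import _get_sharding_prop_cache_info

info = _get_sharding_prop_cache_info()
print(info.hits, info.misses, info.maxsize, info.currsize)
```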


@@ -11,10 +11,10 @@ from torch.distributed._tensor import (
     distribute_module,
     distribute_tensor,
     DTensor,
-    Placement,
     Replicate,
     Shard,
 )
+from torch.distributed._tensor.placement_types import Placement
 __all__ = [