Enable local tensor mode for DTensor attention and convolution tests (#166406)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166406
Approved by: https://github.com/ezyang

parent 5cbdade914
commit 791ca80d3a
@@ -52,7 +52,9 @@ from torch.testing._internal.common_cuda import (
 from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
 from torch.testing._internal.common_utils import run_tests, skipIfRocm
 from torch.testing._internal.distributed._tensor.common_dtensor import (
+    create_local_tensor_test_class,
     DTensorTestBase,
+    map_local_tensor_for_rank,
     with_comms,
 )
 
@@ -800,11 +802,47 @@ class TestSharding(DTensorTestBase):
         chunks = freqs_cis.chunk(self.world_size * 2)
         self.assertEqual(
             freqs_cis_shard,
-            torch.cat(
-                [chunks[self.rank], chunks[self.world_size * 2 - self.rank - 1]], dim=0
+            map_local_tensor_for_rank(
+                chunks,
+                self.rank,
+                lambda chunks, rank: torch.cat(
+                    [chunks[rank], chunks[self.world_size * 2 - rank - 1]],
+                    dim=0,
+                ),
             ),
         )
 
 
+RingAttentionTestWithLocalTensor = create_local_tensor_test_class(
+    RingAttentionTest,
+    skipped_tests=[
+        # Need to make attention implementation local tensor friendly, e.g.
+        # rewrite "rank local" logic
+        "test_ring_attention_sdpa",
+    ],
+)
+
+CPFlexAttentionTestWithLocalTensor = create_local_tensor_test_class(
+    CPFlexAttentionTest,
+    skipped_tests=[
+        # Missing support for batched tensors
+        "test_cp_flex_attention_causal_mask",
+        "test_cp_flex_attention_document_mask",
+    ],
+)
+
+TestCPCustomOpsWithLocalTensor = create_local_tensor_test_class(
+    TestCPCustomOps,
+    skipped_tests=[
+        # Missing support for fake tensors
+        "test_flex_cp_custom_op",
+    ],
+)
+
+TestShardingWithLocalTensor = create_local_tensor_test_class(
+    TestSharding,
+)
+
+
 if __name__ == "__main__":
     run_tests()
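Note on map_local_tensor_for_rank: under local tensor mode a single process
simulates every rank, so rank-dependent indexing such as chunks[self.rank]
cannot be evaluated just once; the test instead hands the rank-dependent
computation to the helper as a callable. The helper itself is defined in
common_dtensor.py and is not shown in this diff; the following is only a
sketch of the idea, with hypothetical internals:

    # Hypothetical sketch -- the real helper in common_dtensor.py may differ.
    def map_local_tensor_for_rank(val, rank, func):
        lm = local_tensor_mode()  # assumed mode query from _local_tensor
        if lm is None:
            # Regular multi-process run: rank is this process's real rank.
            return func(val, rank)
        # Local tensor mode: evaluate the callable once per simulated rank
        # and recombine the per-rank results into a LocalTensor.
        return LocalTensor({r: func(val, r) for r in lm.ranks})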
@@ -16,6 +16,7 @@ from torch.distributed.tensor import (
 from torch.nn import functional as F
 from torch.testing._internal.common_utils import run_tests
 from torch.testing._internal.distributed._tensor.common_dtensor import (
+    create_local_tensor_test_class,
     DTensorTestBase,
     skip_if_lt_x_gpu,
     with_comms,
@@ -232,5 +233,17 @@ class DistConvolutionOpsTest(DTensorTestBase):
         self.assertEqual(out_dt.shape, out.shape)
 
 
+DistConvolutionOpsTestWithLocalTensor = create_local_tensor_test_class(
+    DistConvolutionOpsTest,
+    # Send / recv ops are not supported
+    skipped_tests=[
+        "test_conv1d",
+        "test_conv3d",
+        "test_conv_backward_none_grad_inp",
+        "test_depthwise_convolution",
+        "test_downsampling_convolution",
+    ],
+)
+
 if __name__ == "__main__":
     run_tests()
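Note on create_local_tensor_test_class: the factory (defined in
common_dtensor.py, not shown here) derives a LocalTensorMode-backed variant of
an existing DTensorTestBase suite, with the named tests skipped. A rough,
hypothetical sketch of what such a factory does:

    # Hypothetical sketch -- the real factory in common_dtensor.py may differ.
    def create_local_tensor_test_class(base_class, skipped_tests=()):
        def _make_skip(name):
            def _skipped(self):
                self.skipTest(f"{name} is not local-tensor friendly yet")
            return _skipped

        body = {name: _make_skip(name) for name in skipped_tests}
        # LocalDTensorTestBase (see the last hunk of this diff) swaps
        # multi-process comms for a single-process LocalTensorMode.
        return type(
            f"{base_class.__name__}WithLocalTensor",
            (LocalDTensorTestBase, base_class),
            body,
        )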
@@ -1,5 +1,7 @@
 from ast import Call
 
+from torch._ops import OpOverload
+
 
 """
 A LocalTensor is a tensor subclass which simulates a tensor that is
@@ -65,12 +67,14 @@ import torch
 from torch import Size, SymBool, SymInt, Tensor
 from torch._C import DispatchKey, DispatchKeySet, ScriptObject
 from torch._export.wrappers import mark_subclass_constructor_exportable_experimental
+from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
 from torch.distributed import DeviceMesh, ProcessGroup
 from torch.distributed._functional_collectives import AsyncCollectiveTensor
 from torch.distributed.distributed_c10d import _get_default_group
 from torch.fx.experimental._constant_symnode import ConstantIntNode
 from torch.nested._internal.nested_int import NestedIntNode
 from torch.utils import _pytree as pytree
+from torch.utils._mode_utils import no_dispatch
 from torch.utils._python_dispatch import return_and_correct_aliasing, TorchDispatchMode
 from torch.utils.checkpoint import get_device_states, set_device_states
 
@@ -81,6 +85,19 @@ not_implemented_log = torch._logging.getArtifactLogger(__name__, "not_implemente
 from . import _c10d
 
 
+def _is_inplace_op(op: OpOverload | Callable[..., Any]) -> bool:
+    return (
+        isinstance(op, OpOverload)
+        # Imprecise heuristic for detecting an in-place operation
+        and op._schema.name[-1] == "_"
+        # Strengthen the heuristic: the first argument and the return value must be writes
+        and len(op._schema.arguments) > 0
+        and op._schema.arguments[0].is_write
+        and len(op._schema.returns) > 0
+        and op._schema.returns[0].is_write
+    )
+
+
 def _int_on_rank(i: "int | LocalIntNode | ConstantIntNode", r: int) -> int:
     if isinstance(i, LocalIntNode):
         return i._local_ints[r]
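The heuristic can be sanity-checked against real aten overloads; all three
facts below follow from the schemas ("aten::add_" writes to and returns its
first argument, "aten::add" does not):

    assert _is_inplace_op(torch.ops.aten.add_.Tensor)     # name ends in "_", self is a write
    assert not _is_inplace_op(torch.ops.aten.add.Tensor)  # out-of-place variant
    assert not _is_inplace_op(lambda x: x)                # not an OpOverload at all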
@@ -100,7 +117,13 @@ def _check_for_subclass_arg(x: object) -> bool:
     return (
         not isinstance(x, LocalTensor)
         and isinstance(x, Tensor)
-        and type(x) not in (Tensor, torch.nn.Parameter, torch.nn.Buffer)
+        and type(x)
+        not in (
+            Tensor,
+            FakeTensor,
+            torch.nn.Parameter,
+            torch.nn.Buffer,
+        )
     )
 
 
@@ -220,7 +243,7 @@ def _zero_sized_like(tensor: torch.Tensor, dim: int) -> torch.Tensor:
 
 
 def _for_each_rank_run_func(
-    func: Callable[..., Any],
+    func: OpOverload | Callable[..., Any],
     ranks: frozenset[int],
     args: Sequence[Any],
     kwargs: dict[str, Any],
@@ -256,7 +279,15 @@
         split_dim = 0 if len(rank_flat_args) < 3 else rank_flat_args[2]
         default_value = _zero_sized_like(tensor, split_dim)
 
-    ret = _combine_rank_results(flat_rank_rets, default_value)
+    if _is_inplace_op(func):
+        alias = False
+        # For in-place ops, return self
+        ret = args[0]
+        if isinstance(func, OpOverload) and torch.Tag.inplace_view in func.tags:
+            # Ensure that the wrapper tensor size is synchronized with its local tensors
+            ret._sync_meta()
+    else:
+        ret = _combine_rank_results(flat_rank_rets, default_value)
 
     if alias:
         return return_and_correct_aliasing(func, args, kwargs, ret)
@@ -386,6 +417,11 @@ class LocalIntNode:
         r = {self._local_ints[r] != _int_on_rank(other, r) for r in self._local_ints}
         return torch._C._get_constant_bool_symnode(len(r) > 1 or next(iter(r)))
 
+    def ge(self, other: "int | LocalIntNode | ConstantIntNode") -> bool | SymBool:
+        r = {self._local_ints[r] >= _int_on_rank(other, r) for r in self._local_ints}
+        assert len(r) == 1, (self, other)
+        return torch._C._get_constant_bool_symnode(next(iter(r)))
+
     def gt(self, other: "int | LocalIntNode | ConstantIntNode") -> bool | SymBool:
         r = {self._local_ints[r] > _int_on_rank(other, r) for r in self._local_ints}
         assert len(r) == 1, (self, other)
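Unlike ne above, which may collapse per-rank disagreement into a constant-true
result, the new ge (like the existing gt) insists that every rank agree.
Illustration with invented per-rank values:

    n = LocalIntNode({0: 4, 1: 6})  # rank 0 sees 4, rank 1 sees 6
    n.ge(3)  # {4 >= 3, 6 >= 3} == {True}        -> constant-True SymBool
    n.ge(5)  # {4 >= 5, 6 >= 5} == {True, False} -> AssertionError: ranks disagree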
@@ -400,6 +436,93 @@ class LocalIntNode:
         return ConstantIntNode(num)
 
 
+_LOCAL_TENSOR_ATTR_PREFIX = "_local_tensor_"
+
+
+def _is_local_tensor_attr(attr: str) -> bool:
+    return attr.startswith(_LOCAL_TENSOR_ATTR_PREFIX)
+
+
+def _to_local_tensor_attr(rank: int) -> str:
+    return f"{_LOCAL_TENSOR_ATTR_PREFIX}{rank}"
+
+
+def _from_local_tensor_attr(attr: str) -> int:
+    if not _is_local_tensor_attr(attr):
+        raise AssertionError(f"Invalid local tensor attr {attr}")
+    return int(attr[len(_LOCAL_TENSOR_ATTR_PREFIX) :])
+
+
+def _all_elements_same(values: list[Any]) -> bool:
+    if not values:
+        return True
+    first_value = values[0]
+    return all(value == first_value for value in values)
+
+
+def _compute_local_tensor_meta(
+    local_tensors: dict[int, torch.Tensor],
+) -> tuple[
+    list[torch.SymInt | int],
+    list[torch.SymInt | int],
+    torch.device,
+    torch.dtype,
+    torch.layout,
+    DispatchKeySet,
+]:
+    """
+    Computes the meta information for a LocalTensor from its local tensors.
+    """
+    it = iter(local_tensors.values())
+    first_local_tensor = next(it)
+
+    first_shape = first_local_tensor.shape
+    first_stride = first_local_tensor.stride()
+    dtype = first_local_tensor.dtype
+    device = first_local_tensor.device
+    layout = first_local_tensor.layout
+
+    extra_dispatch_keys = _get_extra_dispatch_keys(first_local_tensor)
+
+    # Assert that all tensors have the same dtype, layout and dispatch keys. Due
+    # to uneven sharding, it is possible that tensors will have different shapes.
+    for local_tensor in it:
+        assert dtype == local_tensor.dtype, (
+            "Tensors representing LocalTensor shards must have the same dtype"
+        )
+        assert layout == local_tensor.layout, (
+            "Tensors representing LocalTensor shards must have the same layout"
+        )
+        assert extra_dispatch_keys == _get_extra_dispatch_keys(local_tensor), (
+            "Tensors representing LocalTensor shards must have the same set of extra dispatch keys"
+        )
+
+    # Compute shape/stride. We allow for non-SPMD'ness here
+    local_shapes: dict[int, dict[int, int]] = defaultdict(dict)  # dim => rank => size
+    local_strides: dict[int, dict[int, int]] = defaultdict(dict)  # dim => rank => stride
+    for r, local_tensor in local_tensors.items():
+        for d, size in enumerate(local_tensor.shape):
+            local_shapes[d][r] = size
+            local_strides[d][r] = local_tensor.stride(d)
+    shape = [
+        (
+            first_shape[d]
+            if _all_elements_same(list(local_shapes[d].values()))
+            else torch.SymInt(LocalIntNode(local_shapes[d]))
+        )
+        for d in range(len(first_shape))
+    ]
+    strides = [
+        (
+            first_stride[d]
+            if _all_elements_same(list(local_strides[d].values()))
+            else torch.SymInt(LocalIntNode(local_strides[d]))
+        )
+        for d in range(len(first_shape))
+    ]
+    return shape, strides, device, dtype, layout, extra_dispatch_keys
+
+
 class LocalTensor(torch.Tensor):
     """
     LocalTensor is a Tensor subclass that simulates a tensor distributed across multiple SPMD
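Two things are worth noting about the extracted helper. First, it fixes a
subtle bug in the old inline check: len(set(local_shapes[d])) == 1 iterated
the inner dict's keys (the ranks), not the sizes, so any multi-rank tensor
took the symbolic branch even when all shards agreed; _all_elements_same
compares the size values themselves. Second, a dimension on which ranks
disagree becomes a SymInt backed by LocalIntNode rather than a single made-up
number. Invented shard sizes, for illustration:

    local_tensors = {
        0: torch.randn(4, 8),  # rank 0 shard
        1: torch.randn(6, 8),  # rank 1 shard, uneven along dim 0
    }
    shape, strides, *_ = _compute_local_tensor_meta(local_tensors)
    # shape[0] is torch.SymInt(LocalIntNode({0: 4, 1: 6})); shape[1] is the plain int 8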
@@ -418,13 +541,15 @@ class LocalTensor(torch.Tensor):
     _local_tensors: dict[int, torch.Tensor]
     # Precomputed for speed set of keys from the local tensor map.
     _ranks: frozenset[int]
-    __slots__ = ["_local_tensors", "_ranks"]
+    _size: list[torch.SymInt | int]
+    __slots__ = ["_local_tensors", "_ranks", "_size"]
 
     @staticmethod
     @torch._disable_dynamo
     def __new__(
         cls,
         local_tensors: dict[int, torch.Tensor],
+        requires_grad: bool = False,
     ) -> "LocalTensor":
         if any(t.requires_grad for t in local_tensors.values()):
             raise AssertionError(
@@ -432,57 +557,9 @@ class LocalTensor(torch.Tensor):
                 "Make a custom autograd function and make sure you detach the inner tensors."
             )
 
-        it = iter(local_tensors.values())
-        first_local_tensor = next(it)
-
-        first_shape = first_local_tensor.shape
-        first_stride = first_local_tensor.stride()
-        dtype = first_local_tensor.dtype
-        device = first_local_tensor.device
-        layout = first_local_tensor.layout
-
-        extra_dispatch_keys = _get_extra_dispatch_keys(first_local_tensor)
-
-        # Assert that all tensors have the same dtype, layout and dispatch keys. Due
-        # to uneven sharding, it is possible that tensors will have different shapes.
-        for local_tensor in it:
-            assert dtype == local_tensor.dtype, (
-                "Tensors representing LocalTensor shards must have the same dtype"
-            )
-            assert layout == local_tensor.layout, (
-                "Tensors representing LocalTensor shards must have the same layout"
-            )
-            assert extra_dispatch_keys == _get_extra_dispatch_keys(local_tensor), (
-                "Tensors representing LocalTensor shards must have the same set of extra dispatch keys"
-            )
-
-        # Compute shape/stride. We allow for non-SPMD'ness here
-        local_shapes: dict[int, dict[int, int]] = defaultdict(
-            dict
-        )  # dim => rank => size
-        local_strides: dict[int, dict[int, int]] = defaultdict(
-            dict
-        )  # dim => rank => size
-        for r, local_tensor in local_tensors.items():
-            for d, size in enumerate(local_tensor.shape):
-                local_shapes[d][r] = size
-                local_strides[d][r] = local_tensor.stride(d)
-        shape = [
-            (
-                first_shape[d]
-                if len(set(local_shapes[d])) == 1
-                else torch.SymInt(LocalIntNode(local_shapes[d]))
-            )
-            for d in range(len(first_shape))
-        ]
-        strides = [
-            (
-                first_stride[d]
-                if len(set(local_strides[d])) == 1
-                else torch.SymInt(LocalIntNode(local_strides[d]))
-            )
-            for d in range(len(first_shape))
-        ]
-
+        (shape, strides, device, dtype, layout, extra_dispatch_keys) = (
+            _compute_local_tensor_meta(local_tensors)
+        )
 
         r = torch.Tensor._make_wrapper_subclass(
             cls,
@@ -491,7 +568,13 @@
             dtype=dtype,
             device=device,
             layout=layout,
-            requires_grad=False,
+            # In-place ops potentially change local tensor sizes (e.g. resize_). While
+            # executing an in-place op the return value must be the same as the "self"
+            # input, otherwise we can introduce errors due to tensor identity changes.
+            # Hence we need to be able to update wrapper subclass sizes after in-place
+            # ops. This dispatch policy allows us to do that.
+            dispatch_sizes_strides_policy="sizes",
+            requires_grad=requires_grad,
             _extra_dispatch_keys=extra_dispatch_keys,
         )
 
@@ -501,6 +584,7 @@ class LocalTensor(torch.Tensor):
         }
         r._local_tensors = local_tensors
         r._ranks = frozenset(local_tensors.keys())
+        r._size = shape
         return r
 
     @torch._disable_dynamo
@@ -512,9 +596,7 @@ class LocalTensor(torch.Tensor):
         local_tensors_copy = {
             r: copy.deepcopy(t, memo) for r, t in self._local_tensors.items()
         }
-        tensor_copy = LocalTensor(local_tensors_copy)
-        tensor_copy.requires_grad = self.requires_grad
-        return tensor_copy
+        return LocalTensor(local_tensors_copy, self.requires_grad)
 
     def __repr__(self) -> str:  # type: ignore[override]
         parts = []
@@ -524,12 +606,21 @@ class LocalTensor(torch.Tensor):
         tensors_str = ",\n".join(parts)
         return f"LocalTensor(\n{tensors_str}\n)"
 
+    def __getattr__(self, name: str) -> Any:
+        if _is_local_tensor_attr(name):
+            rank = _from_local_tensor_attr(name)
+            if rank not in self._ranks:
+                raise AttributeError(f"Local tensor has no knowledge of rank {rank}")
+            return self._local_tensors[rank]
+        return object.__getattribute__(self, name)
+
     def __tensor_flatten__(self) -> tuple[list[str], tuple[Any, ...]]:
         """
         protocol to inform how to flatten a DTensor to local tensor
         for PT2 tracing
         """
-        return ["_local_tensors"], ()
+        local_tensor_attrs = [_to_local_tensor_attr(r) for r in self._ranks]
+        return local_tensor_attrs, ()
 
     @staticmethod
     def __tensor_unflatten__(
@@ -541,8 +632,9 @@ class LocalTensor(torch.Tensor):
         assert flatten_spec is not None, (
             "Expecting spec to be not None from `__tensor_flatten__` return value!"
         )
-        local_tensors = inner_tensors["_local_tensors"]
-        # pyrefly: ignore [bad-argument-type, bad-argument-count]
+        local_tensors = {
+            _from_local_tensor_attr(a): t for a, t in inner_tensors.items()
+        }
         return LocalTensor(local_tensors)
 
     @classmethod
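The flatten protocol now exposes one inner tensor per simulated rank under a
synthesized attribute name, with the new __getattr__ resolving those names and
__tensor_unflatten__ parsing the rank back out of them. An illustrative round
trip (calling the hooks directly; PT2 tracing is the real caller, and the
trailing size/stride arguments are abbreviated here):

    lt = LocalTensor({0: torch.zeros(2), 1: torch.ones(2)})
    attrs, spec = lt.__tensor_flatten__()       # ["_local_tensor_0", "_local_tensor_1"], ()
    inner = {a: getattr(lt, a) for a in attrs}  # served by the new __getattr__
    lt2 = LocalTensor.__tensor_unflatten__(inner, spec, None, None)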
@@ -591,24 +683,6 @@ class LocalTensor(torch.Tensor):
         else:
             raise RuntimeError("Numpy is not available")
 
-    def __lt__(
-        self, other: torch.Tensor | bool | complex | float | int
-    ) -> torch.Tensor:
-        self_rec = self.reconcile()
-        other_rec = other
-        if isinstance(other, LocalTensor):
-            other_rec = other.reconcile()
-        return self_rec < other_rec
-
-    def __gt__(
-        self, other: torch.Tensor | bool | complex | float | int
-    ) -> torch.Tensor:
-        self_rec = self.reconcile()
-        other_rec = other
-        if isinstance(other, LocalTensor):
-            other_rec = other.reconcile()
-        return self_rec > other_rec
-
     def contiguous(
         self,
         memory_format: torch.memory_format = torch.contiguous_format,
@@ -660,6 +734,13 @@ class LocalTensor(torch.Tensor):
         cl.requires_grad_(self.requires_grad)
         return cl
 
+    def _sync_meta(self) -> None:
+        with no_dispatch():
+            (shape, strides, device, dtype, layout, extra_dispatch_keys) = (
+                _compute_local_tensor_meta(self._local_tensors)
+            )
+            self._size = shape
+
 
 _LOCAL_TENSOR_MODE: list["LocalTensorMode"] = []
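This ties the pieces together: dispatch_sizes_strides_policy="sizes" lets size
queries reach the dispatcher, _for_each_rank_run_func returns `self` for
in-place ops, and for inplace_view ops it calls _sync_meta() so the cached
_size reflects the new per-rank metadata, which the aten.dim/aten.sym_size
interception in the next hunk then serves. A sketch of the intended flow,
using resize_ (the op the new comment itself cites; its inplace_view tagging
is assumed here):

    lt = LocalTensor({0: torch.zeros(4), 1: torch.zeros(4)})
    with LocalTensorMode(frozenset({0, 1})):
        lt.resize_(8)        # runs per rank; if tagged inplace_view, the
                             # runner returns lt and calls lt._sync_meta()
        lt.dim(), lt.size()  # answered from the recomputed lt._size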
@@ -753,6 +834,11 @@ class LocalTensorMode(TorchDispatchMode):
                 f"Input LocalTensor {a} and LocalTensorMode must be configured for the same ranks"
             )
 
+        if func.overloadpacket == torch.ops.aten.dim:
+            return len(args[0]._size)
+        if func.overloadpacket == torch.ops.aten.sym_size:
+            return tuple(args[0]._size)
+
         if func.namespace == "c10d":
             if func is torch.ops.c10d.allreduce_.default:
                 return _c10d._local_all_reduce_(*args, **kwargs)
@@ -17,7 +17,6 @@ import torch.distributed as dist
 import torch.nn as nn
 import torch.nn.functional as F
 from torch.distributed._local_tensor import (
-    local_tensor_mode,
     LocalIntNode,
     LocalTensor,
     LocalTensorMode,
@@ -715,9 +714,6 @@ class LocalDTensorTestBase(DTensorTestBase):
             self.skipTest(msg)
 
     def _get_local_tensor_mode(self):
-        lm = local_tensor_mode()
-        if lm is not None:
-            breakpoint()
         return LocalTensorMode(frozenset(range(self.world_size)))
 
     def setUp(self) -> None: