Enable local tensor mode for DTensor attention and convolution tests (#166406)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166406
Approved by: https://github.com/ezyang
Dzmitry Huba 2025-10-29 16:45:55 -07:00 committed by PyTorch MergeBot
parent 5cbdade914
commit 791ca80d3a
4 changed files with 219 additions and 86 deletions

View File

@@ -52,7 +52,9 @@ from torch.testing._internal.common_cuda import (
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_utils import run_tests, skipIfRocm
from torch.testing._internal.distributed._tensor.common_dtensor import (
create_local_tensor_test_class,
DTensorTestBase,
map_local_tensor_for_rank,
with_comms,
)
@@ -800,11 +802,47 @@ class TestSharding(DTensorTestBase):
chunks = freqs_cis.chunk(self.world_size * 2)
self.assertEqual(
freqs_cis_shard,
torch.cat(
[chunks[self.rank], chunks[self.world_size * 2 - self.rank - 1]], dim=0
map_local_tensor_for_rank(
chunks,
self.rank,
lambda chunks, rank: torch.cat(
[chunks[rank], chunks[self.world_size * 2 - rank - 1]],
dim=0,
),
),
)
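# Illustrative sketch (assumptions, not the real helper): map_local_tensor_for_rank lets a
# single in-process test express rank-dependent expectations. With plain tensors it would
# simply call the callback for the given rank; under local tensor mode it would evaluate the
# callback once per simulated rank (LocalTensor is assumed to be importable here).
def _sketch_map_local_tensor_for_rank(val, rank, func, simulated_ranks=None):
    if simulated_ranks is None:
        # behaves exactly like the old inline expression it replaced above
        return func(val, rank)
    return LocalTensor({r: func(val, r) for r in simulated_ranks})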
RingAttentionTestWithLocalTensor = create_local_tensor_test_class(
RingAttentionTest,
skipped_tests=[
# Need to make attention implementation local tensor friendly, e.g.
# rewrite "rank local" logic
"test_ring_attention_sdpa",
],
)
CPFlexAttentionTestWithLocalTensor = create_local_tensor_test_class(
CPFlexAttentionTest,
skipped_tests=[
# Missing support for batched tensors
"test_cp_flex_attention_causal_mask",
"test_cp_flex_attention_document_mask",
],
)
TestCPCustomOpsWithLocalTensor = create_local_tensor_test_class(
TestCPCustomOps,
skipped_tests=[
# Missing support for fake tensors
"test_flex_cp_custom_op",
],
)
TestShardingWithLocalTensor = create_local_tensor_test_class(
TestSharding,
)
if __name__ == "__main__":
run_tests()
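# Rough sketch (assumptions, not the real factory): create_local_tensor_test_class is used
# above to derive "...WithLocalTensor" variants of existing test classes. A factory like this
# could build a subclass that mixes in a LocalTensor-enabled base (here assumed to be
# LocalDTensorTestBase from common_dtensor) and stubs out the listed incompatible tests.
def _sketch_create_local_tensor_test_class(base_cls, skipped_tests=()):
    def _skip(self):
        self.skipTest("not supported with LocalTensor yet")
    overrides = {name: _skip for name in skipped_tests}
    return type(f"{base_cls.__name__}WithLocalTensor", (LocalDTensorTestBase, base_cls), overrides)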

View File

@@ -16,6 +16,7 @@ from torch.distributed.tensor import (
from torch.nn import functional as F
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
create_local_tensor_test_class,
DTensorTestBase,
skip_if_lt_x_gpu,
with_comms,
@@ -232,5 +233,17 @@ class DistConvolutionOpsTest(DTensorTestBase):
self.assertEqual(out_dt.shape, out.shape)
DistConvolutionOpsTestWithLocalTensor = create_local_tensor_test_class(
DistConvolutionOpsTest,
# Send / recv ops are not supported
skipped_tests=[
"test_conv1d",
"test_conv3d",
"test_conv_backward_none_grad_inp",
"test_depthwise_convolution",
"test_downsampling_convolution",
],
)
if __name__ == "__main__":
run_tests()

View File

@@ -1,5 +1,7 @@
from ast import Call
from torch._ops import OpOverload
"""
A LocalTensor is a tensor subclass which simulates a tensor that is
@@ -65,12 +67,14 @@ import torch
from torch import Size, SymBool, SymInt, Tensor
from torch._C import DispatchKey, DispatchKeySet, ScriptObject
from torch._export.wrappers import mark_subclass_constructor_exportable_experimental
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
from torch.distributed import DeviceMesh, ProcessGroup
from torch.distributed._functional_collectives import AsyncCollectiveTensor
from torch.distributed.distributed_c10d import _get_default_group
from torch.fx.experimental._constant_symnode import ConstantIntNode
from torch.nested._internal.nested_int import NestedIntNode
from torch.utils import _pytree as pytree
from torch.utils._mode_utils import no_dispatch
from torch.utils._python_dispatch import return_and_correct_aliasing, TorchDispatchMode
from torch.utils.checkpoint import get_device_states, set_device_states
@@ -81,6 +85,19 @@ not_implemented_log = torch._logging.getArtifactLogger(__name__, "not_implemented")
from . import _c10d
def _is_inplace_op(op: OpOverload | Callable[..., Any]) -> bool:
return (
isinstance(op, OpOverload)
# Imprecise heuristic: in-place ops conventionally end with "_"
and op._schema.name[-1] == "_"
# Strengthen the heuristic by checking that the first argument and the return value are writes
and len(op._schema.arguments) > 0
and op._schema.arguments[0].is_write
and len(op._schema.returns) > 0
and op._schema.returns[0].is_write
)
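# Illustrative check (not part of this diff): the heuristic above accepts a genuinely
# in-place overload and rejects its out-of-place sibling.
assert _is_inplace_op(torch.ops.aten.add_.Tensor)      # "aten::add_": trailing "_", self and return are writes
assert not _is_inplace_op(torch.ops.aten.add.Tensor)   # "aten::add": no trailing "_"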
def _int_on_rank(i: "int | LocalIntNode | ConstantIntNode", r: int) -> int:
if isinstance(i, LocalIntNode):
return i._local_ints[r]
@@ -100,7 +117,13 @@ def _check_for_subclass_arg(x: object) -> bool:
return (
not isinstance(x, LocalTensor)
and isinstance(x, Tensor)
and type(x) not in (Tensor, torch.nn.Parameter, torch.nn.Buffer)
and type(x)
not in (
Tensor,
FakeTensor,
torch.nn.Parameter,
torch.nn.Buffer,
)
)
@@ -220,7 +243,7 @@ def _zero_sized_like(tensor: torch.Tensor, dim: int) -> torch.Tensor:
def _for_each_rank_run_func(
func: Callable[..., Any],
func: OpOverload | Callable[..., Any],
ranks: frozenset[int],
args: Sequence[Any],
kwargs: dict[str, Any],
@@ -256,7 +279,15 @@ def _for_each_rank_run_func(
split_dim = 0 if len(rank_flat_args) < 3 else rank_flat_args[2]
default_value = _zero_sized_like(tensor, split_dim)
ret = _combine_rank_results(flat_rank_rets, default_value)
if _is_inplace_op(func):
alias = False
# For in-place ops, return self
ret = args[0]
if isinstance(func, OpOverload) and torch.Tag.inplace_view in func.tags:
# Ensure that wrapper tensor size is synchronized with its local tensors
ret._sync_meta()
else:
ret = _combine_rank_results(flat_rank_rets, default_value)
if alias:
return return_and_correct_aliasing(func, args, kwargs, ret)
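# Illustrative behavior sketch (assumes two simulated ranks; not part of this diff):
# the in-place branch above hands back the original wrapper object, and inplace_view
# ops such as resize_ trigger _sync_meta() so the wrapper's sizes follow its shards.
lt = LocalTensor({0: torch.zeros(2), 1: torch.zeros(2)})
with LocalTensorMode(frozenset({0, 1})):
    out = lt.add_(1)   # dispatched per rank; `out` is expected to be `lt` itself
    lt.resize_(4)      # tagged inplace_view, so the wrapper re-reads sizes from its shards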
@@ -386,6 +417,11 @@ class LocalIntNode:
r = {self._local_ints[r] != _int_on_rank(other, r) for r in self._local_ints}
return torch._C._get_constant_bool_symnode(len(r) > 1 or next(iter(r)))
def ge(self, other: "int | LocalIntNode | ConstantIntNode") -> bool | SymBool:
r = {self._local_ints[r] >= _int_on_rank(other, r) for r in self._local_ints}
assert len(r) == 1, (self, other)
return torch._C._get_constant_bool_symnode(next(iter(r)))
def gt(self, other: "int | LocalIntNode | ConstantIntNode") -> bool | SymBool:
r = {self._local_ints[r] > _int_on_rank(other, r) for r in self._local_ints}
assert len(r) == 1, (self, other)
@@ -400,6 +436,93 @@ class LocalIntNode:
return ConstantIntNode(num)
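# Illustrative example (not part of this diff): a LocalIntNode carries one integer per
# rank, and comparisons like the ge added above are only defined when all ranks agree.
n = LocalIntNode({0: 4, 1: 6})
n.ge(3)   # 4 >= 3 and 6 >= 3 agree on every rank -> constant True symnode
# n.ge(5) would hit the assert: rank 0 says False, rank 1 says True, so no single answer exists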
_LOCAL_TENSOR_ATTR_PREFIX = "_local_tensor_"
def _is_local_tensor_attr(attr: str) -> bool:
return attr.startswith(_LOCAL_TENSOR_ATTR_PREFIX)
def _to_local_tensor_attr(rank: int) -> str:
return f"{_LOCAL_TENSOR_ATTR_PREFIX}{rank}"
def _from_local_tensor_attr(attr: str) -> int:
if not _is_local_tensor_attr(attr):
raise AssertionError(f"Invalid local tensor attr {attr}")
return int(attr[len(_LOCAL_TENSOR_ATTR_PREFIX) :])
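# Example (not part of this diff): the prefix above makes per-rank inner tensors addressable
# as attributes, and the two helpers round-trip rank <-> attribute name.
assert _to_local_tensor_attr(3) == "_local_tensor_3"
assert _from_local_tensor_attr("_local_tensor_3") == 3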
def _all_elements_same(values: list[Any]) -> bool:
if not values:
return True
first_value = values[0]
return all(value == first_value for value in values)
def _compute_local_tensor_meta(
local_tensors: dict[int, torch.Tensor],
) -> tuple[
list[torch.SymInt | int],
list[torch.SymInt | int],
torch.device,
torch.dtype,
torch.layout,
DispatchKeySet,
]:
"""
Computes the meta information for a LocalTensor from its local tensors.
"""
it = iter(local_tensors.values())
first_local_tensor = next(it)
first_shape = first_local_tensor.shape
first_stride = first_local_tensor.stride()
dtype = first_local_tensor.dtype
device = first_local_tensor.device
layout = first_local_tensor.layout
extra_dispatch_keys = _get_extra_dispatch_keys(first_local_tensor)
# Assert that all tensors have the same dtype, layout and dispatch keys. Due
# to uneven sharding, it is possible that tensors will have different shapes.
for local_tensor in it:
assert dtype == local_tensor.dtype, (
"Tensors representing LocalTensor shards must have the same dtype"
)
assert layout == local_tensor.layout, (
"Tensors representing LocalTensor shards must have the same layout"
)
assert extra_dispatch_keys == _get_extra_dispatch_keys(local_tensor), (
"Tensors representing LocalTensor shards must have the same set of extra dispatch keys"
)
# Compute shape/stride. We allow for non-SPMD'ness here
local_shapes: dict[int, dict[int, int]] = defaultdict(dict) # dim => rank => size
local_strides: dict[int, dict[int, int]] = defaultdict(dict) # dim => rank => size
for r, local_tensor in local_tensors.items():
for d, size in enumerate(local_tensor.shape):
local_shapes[d][r] = size
local_strides[d][r] = local_tensor.stride(d)
shape = [
(
first_shape[d]
if _all_elements_same(list(local_shapes[d].values()))
else torch.SymInt(LocalIntNode(local_shapes[d]))
)
for d in range(len(first_shape))
]
strides = [
(
first_stride[d]
if _all_elements_same(list(local_strides[d].values()))
else torch.SymInt(LocalIntNode(local_strides[d]))
)
for d in range(len(first_shape))
]
return shape, strides, device, dtype, layout, extra_dispatch_keys
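# Illustrative sketch (not part of this diff): with even shards the wrapper metadata uses
# plain ints; with uneven shards the differing dimension becomes a SymInt backed by a
# LocalIntNode recording one size per rank.
shape, strides, device, dtype, layout, keys = _compute_local_tensor_meta(
    {0: torch.zeros(2, 3), 1: torch.zeros(4, 3)}
)
# shape[0] wraps LocalIntNode({0: 2, 1: 4}); shape[1] is the plain int 3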
class LocalTensor(torch.Tensor):
"""
LocalTensor is a Tensor subclass that simulates a tensor distributed across multiple SPMD
@@ -418,13 +541,15 @@ class LocalTensor(torch.Tensor):
_local_tensors: dict[int, torch.Tensor]
# Precomputed for speed set of keys from the local tensor map.
_ranks: frozenset[int]
__slots__ = ["_local_tensors", "_ranks"]
_size: list[torch.SymInt | int]
__slots__ = ["_local_tensors", "_ranks", "_size"]
@staticmethod
@torch._disable_dynamo
def __new__(
cls,
local_tensors: dict[int, torch.Tensor],
requires_grad: bool = False,
) -> "LocalTensor":
if any(t.requires_grad for t in local_tensors.values()):
raise AssertionError(
@@ -432,57 +557,9 @@ class LocalTensor(torch.Tensor):
"Make a custom autograd function and make sure you detach the inner tensors."
)
it = iter(local_tensors.values())
first_local_tensor = next(it)
first_shape = first_local_tensor.shape
first_stride = first_local_tensor.stride()
dtype = first_local_tensor.dtype
device = first_local_tensor.device
layout = first_local_tensor.layout
extra_dispatch_keys = _get_extra_dispatch_keys(first_local_tensor)
# Assert that all tensors have the same dtype, layout and dispatch keys. Due
# to uneven sharding, it is possible that tensors will have different shapes.
for local_tensor in it:
assert dtype == local_tensor.dtype, (
"Tensors representing LocalTensor shards must have the same dtype"
)
assert layout == local_tensor.layout, (
"Tensors representing LocalTensor shards must have the same layout"
)
assert extra_dispatch_keys == _get_extra_dispatch_keys(local_tensor), (
"Tensors representing LocalTensor shards must have the same set of extra dispatch keys"
)
# Compute shape/stride. We allow for non-SPMD'ness here
local_shapes: dict[int, dict[int, int]] = defaultdict(
dict
) # dim => rank => size
local_strides: dict[int, dict[int, int]] = defaultdict(
dict
) # dim => rank => size
for r, local_tensor in local_tensors.items():
for d, size in enumerate(local_tensor.shape):
local_shapes[d][r] = size
local_strides[d][r] = local_tensor.stride(d)
shape = [
(
first_shape[d]
if len(set(local_shapes[d])) == 1
else torch.SymInt(LocalIntNode(local_shapes[d]))
)
for d in range(len(first_shape))
]
strides = [
(
first_stride[d]
if len(set(local_strides[d])) == 1
else torch.SymInt(LocalIntNode(local_strides[d]))
)
for d in range(len(first_shape))
]
(shape, strides, device, dtype, layout, extra_dispatch_keys) = (
_compute_local_tensor_meta(local_tensors)
)
r = torch.Tensor._make_wrapper_subclass(
cls,
@@ -491,7 +568,13 @@ class LocalTensor(torch.Tensor):
dtype=dtype,
device=device,
layout=layout,
requires_grad=False,
# In-place ops can change local tensor sizes (e.g. resize_). While executing an
# in-place op the return value must be the same object as the "self" input,
# otherwise we can introduce errors due to tensor identity changes. Hence we
# need to be able to update the wrapper subclass sizes after in-place ops. This
# dispatch policy allows us to do that.
dispatch_sizes_strides_policy="sizes",
requires_grad=requires_grad,
_extra_dispatch_keys=extra_dispatch_keys,
)
@@ -501,6 +584,7 @@ class LocalTensor(torch.Tensor):
}
r._local_tensors = local_tensors
r._ranks = frozenset(local_tensors.keys())
r._size = shape
return r
@torch._disable_dynamo
@@ -512,9 +596,7 @@ class LocalTensor(torch.Tensor):
local_tensors_copy = {
r: copy.deepcopy(t, memo) for r, t in self._local_tensors.items()
}
tensor_copy = LocalTensor(local_tensors_copy)
tensor_copy.requires_grad = self.requires_grad
return tensor_copy
return LocalTensor(local_tensors_copy, self.requires_grad)
def __repr__(self) -> str: # type: ignore[override]
parts = []
@@ -524,12 +606,21 @@ class LocalTensor(torch.Tensor):
tensors_str = ",\n".join(parts)
return f"LocalTensor(\n{tensors_str}\n)"
def __getattr__(self, name: str) -> Any:
if _is_local_tensor_attr(name):
rank = _from_local_tensor_attr(name)
if rank not in self._ranks:
raise AttributeError(f"Local tensor has no knowledge of rank {rank}")
return self._local_tensors[rank]
return object.__getattribute__(self, name)
def __tensor_flatten__(self) -> tuple[list[str], tuple[Any, ...]]:
"""
protocol to inform how to flatten a LocalTensor into its local tensors
for PT2 tracing
"""
return ["_local_tensors"], ()
local_tensor_attrs = [_to_local_tensor_attr(r) for r in self._ranks]
return local_tensor_attrs, ()
@staticmethod
def __tensor_unflatten__(
@@ -541,8 +632,9 @@ class LocalTensor(torch.Tensor):
assert flatten_spec is not None, (
"Expecting spec to be not None from `__tensor_flatten__` return value!"
)
local_tensors = inner_tensors["_local_tensors"]
# pyrefly: ignore [bad-argument-type, bad-argument-count]
local_tensors = {
_from_local_tensor_attr(a): t for a, t in inner_tensors.items()
}
return LocalTensor(local_tensors)
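# Illustrative round trip (not part of this diff; assumes the usual
# __tensor_unflatten__(inner_tensors, spec, outer_size, outer_stride) signature):
lt = LocalTensor({0: torch.zeros(2), 1: torch.ones(2)})
attrs, spec = lt.__tensor_flatten__()            # per-rank names like "_local_tensor_0"
inner = {a: getattr(lt, a) for a in attrs}       # resolved by the new __getattr__
rebuilt = LocalTensor.__tensor_unflatten__(inner, spec, None, None)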
@classmethod
@@ -591,24 +683,6 @@ class LocalTensor(torch.Tensor):
else:
raise RuntimeError("Numpy is not available")
def __lt__(
self, other: torch.Tensor | bool | complex | float | int
) -> torch.Tensor:
self_rec = self.reconcile()
other_rec = other
if isinstance(other, LocalTensor):
other_rec = other.reconcile()
return self_rec < other_rec
def __gt__(
self, other: torch.Tensor | bool | complex | float | int
) -> torch.Tensor:
self_rec = self.reconcile()
other_rec = other
if isinstance(other, LocalTensor):
other_rec = other.reconcile()
return self_rec > other_rec
def contiguous(
self,
memory_format: torch.memory_format = torch.contiguous_format,
@@ -660,6 +734,13 @@ class LocalTensor(torch.Tensor):
cl.requires_grad_(self.requires_grad)
return cl
def _sync_meta(self) -> None:
with no_dispatch():
(shape, strides, device, dtype, layout, extra_dispatch_keys) = (
_compute_local_tensor_meta(self._local_tensors)
)
self._size = shape
_LOCAL_TENSOR_MODE: list["LocalTensorMode"] = []
@@ -753,6 +834,11 @@ class LocalTensorMode(TorchDispatchMode):
f"Input LocalTensor {a} and LocalTensorMode must be configured for the same ranks"
)
if func.overloadpacket == torch.ops.aten.dim:
return len(args[0]._size)
if func.overloadpacket == torch.ops.aten.sym_size:
return tuple(args[0]._size)
if func.namespace == "c10d":
if func is torch.ops.c10d.allreduce_.default:
return _c10d._local_all_reduce_(*args, **kwargs)
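# Illustrative sketch (not part of this diff): with dispatch_sizes_strides_policy="sizes",
# size queries on the wrapper are answered by the mode from the cached _size list, so
# uneven per-rank shapes surface as LocalIntNode-backed SymInts instead of erroring.
lt = LocalTensor({0: torch.zeros(2, 3), 1: torch.zeros(4, 3)})
with LocalTensorMode(frozenset({0, 1})):
    lt.dim()    # routed to aten.dim -> len(lt._size) == 2
    lt.size()   # routed to aten.sym_size -> tuple(lt._size), dim 0 is a SymInt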

View File

@@ -17,7 +17,6 @@ import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed._local_tensor import (
local_tensor_mode,
LocalIntNode,
LocalTensor,
LocalTensorMode,
@@ -715,9 +714,6 @@ class LocalDTensorTestBase(DTensorTestBase):
self.skipTest(msg)
def _get_local_tensor_mode(self):
lm = local_tensor_mode()
if lm is not None:
breakpoint()
return LocalTensorMode(frozenset(range(self.world_size)))
def setUp(self) -> None: