Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-06 12:20:52 +01:00
Enable local tensor mode for DTensor attention and convolution tests (#166406)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166406
Approved by: https://github.com/ezyang
This commit is contained in:
parent 5cbdade914
commit 791ca80d3a
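
Background: LocalTensorMode simulates all SPMD (single program, multiple data) ranks inside one process, so DTensor tests can run without spawning a process per rank. A minimal sketch of entering the mode, following the LocalDTensorTestBase pattern shown at the bottom of this diff (the arithmetic inside is an arbitrary illustration, not part of the commit):

    import torch
    from torch.distributed._local_tensor import LocalTensorMode

    world_size = 4
    # Simulate four SPMD ranks in a single process; ops dispatched inside
    # the mode are applied independently per simulated rank.
    with LocalTensorMode(frozenset(range(world_size))):
        t = torch.ones(2, 2)
        print(t + t)
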
@@ -52,7 +52,9 @@ from torch.testing._internal.common_cuda import (
 from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
 from torch.testing._internal.common_utils import run_tests, skipIfRocm
 from torch.testing._internal.distributed._tensor.common_dtensor import (
+    create_local_tensor_test_class,
     DTensorTestBase,
+    map_local_tensor_for_rank,
     with_comms,
 )
@@ -800,9 +802,45 @@ class TestSharding(DTensorTestBase):
         chunks = freqs_cis.chunk(self.world_size * 2)
         self.assertEqual(
             freqs_cis_shard,
-            torch.cat(
-                [chunks[self.rank], chunks[self.world_size * 2 - self.rank - 1]], dim=0
+            map_local_tensor_for_rank(
+                chunks,
+                self.rank,
+                lambda chunks, rank: torch.cat(
+                    [chunks[rank], chunks[self.world_size * 2 - rank - 1]],
+                    dim=0,
+                ),
             ),
         )


+RingAttentionTestWithLocalTensor = create_local_tensor_test_class(
+    RingAttentionTest,
+    skipped_tests=[
+        # Need to make attention implementation local tensor friendly, e.g.
+        # rewrite "rank local" logic
+        "test_ring_attention_sdpa",
+    ],
+)
+
+CPFlexAttentionTestWithLocalTensor = create_local_tensor_test_class(
+    CPFlexAttentionTest,
+    skipped_tests=[
+        # Missing support for batched tensors
+        "test_cp_flex_attention_causal_mask",
+        "test_cp_flex_attention_document_mask",
+    ],
+)
+
+TestCPCustomOpsWithLocalTensor = create_local_tensor_test_class(
+    TestCPCustomOps,
+    skipped_tests=[
+        # Missing support for fake tensors
+        "test_flex_cp_custom_op",
+    ],
+)
+
+TestShardingWithLocalTensor = create_local_tensor_test_class(
+    TestSharding,
+)
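
The rewritten assertion goes through map_local_tensor_for_rank so the same check works both for a plain per-rank tensor and for a LocalTensor carrying all ranks at once. A toy stand-in for the helper (hypothetical name and body; only the call shape is mirrored, not the LocalTensor fan-out):

    # Hypothetical stand-in: apply `func(value, rank)` for one rank. The real
    # helper in common_dtensor.py additionally applies the function once per
    # simulated rank when `value` is rank-local.
    def map_value_for_rank(value, rank, func):
        return func(value, rank)

    chunks = list(range(8))  # stand-in for freqs_cis.chunk(world_size * 2)
    world_size, rank = 4, 1
    picked = map_value_for_rank(
        chunks, rank, lambda c, r: (c[r], c[world_size * 2 - r - 1])
    )
    print(picked)  # (1, 6): rank 1 pairs chunk 1 with mirrored chunk 6
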
@@ -16,6 +16,7 @@ from torch.distributed.tensor import (
 from torch.nn import functional as F
 from torch.testing._internal.common_utils import run_tests
 from torch.testing._internal.distributed._tensor.common_dtensor import (
+    create_local_tensor_test_class,
     DTensorTestBase,
     skip_if_lt_x_gpu,
     with_comms,
@@ -232,5 +233,17 @@ class DistConvolutionOpsTest(DTensorTestBase):
         self.assertEqual(out_dt.shape, out.shape)


+DistConvolutionOpsTestWithLocalTensor = create_local_tensor_test_class(
+    DistConvolutionOpsTest,
+    # Send / recv ops are not supported
+    skipped_tests=[
+        "test_conv1d",
+        "test_conv3d",
+        "test_conv_backward_none_grad_inp",
+        "test_depthwise_convolution",
+        "test_downsampling_convolution",
+    ],
+)
+
 if __name__ == "__main__":
     run_tests()
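
create_local_tensor_test_class derives a *WithLocalTensor variant of an existing suite that runs under local tensor mode, skipping tests that are not yet compatible. A simplified sketch of the factory pattern (hypothetical helper; the real implementation also swaps in the local-tensor test base class):

    import unittest

    # Hypothetical sketch: derive a new class from an existing suite and
    # mark the listed tests as skipped.
    def make_local_tensor_test_class(base, skipped_tests=()):
        ns = {
            name: unittest.skip("not local-tensor friendly yet")(lambda self: None)
            for name in skipped_tests
        }
        return type(base.__name__ + "WithLocalTensor", (base,), ns)

    class MyTest(unittest.TestCase):
        def test_ok(self):
            self.assertTrue(True)

        def test_not_local_tensor_ready(self):
            self.assertTrue(True)

    MyTestWithLocalTensor = make_local_tensor_test_class(
        MyTest, skipped_tests=["test_not_local_tensor_ready"]
    )
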
@@ -1,5 +1,7 @@
 from ast import Call
+
+from torch._ops import OpOverload


 """
 A LocalTensor is a tensor subclass which simulates a tensor that is
@@ -65,12 +67,14 @@ import torch
 from torch import Size, SymBool, SymInt, Tensor
 from torch._C import DispatchKey, DispatchKeySet, ScriptObject
 from torch._export.wrappers import mark_subclass_constructor_exportable_experimental
+from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
 from torch.distributed import DeviceMesh, ProcessGroup
 from torch.distributed._functional_collectives import AsyncCollectiveTensor
 from torch.distributed.distributed_c10d import _get_default_group
 from torch.fx.experimental._constant_symnode import ConstantIntNode
 from torch.nested._internal.nested_int import NestedIntNode
 from torch.utils import _pytree as pytree
+from torch.utils._mode_utils import no_dispatch
 from torch.utils._python_dispatch import return_and_correct_aliasing, TorchDispatchMode
 from torch.utils.checkpoint import get_device_states, set_device_states

@@ -81,6 +85,19 @@ not_implemented_log = torch._logging.getArtifactLogger(__name__, "not_implemented")
 from . import _c10d


+def _is_inplace_op(op: OpOverload | Callable[..., Any]) -> bool:
+    return (
+        isinstance(op, OpOverload)
+        # Not precise heuristic to detect inplace operation
+        and op._schema.name[-1] == "_"
+        # Strengthen the heuristic to check that the first argument and return value are a write
+        and len(op._schema.arguments) > 0
+        and op._schema.arguments[0].is_write
+        and len(op._schema.returns) > 0
+        and op._schema.returns[0].is_write
+    )
+
+
 def _int_on_rank(i: "int | LocalIntNode | ConstantIntNode", r: int) -> int:
     if isinstance(i, LocalIntNode):
         return i._local_ints[r]
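
The schema checks above can be verified directly against standard ATen operators; for illustration (shown only to demonstrate what the heuristic keys on):

    import torch

    # aten::add_ mutates and returns `self`: name ends with "_", and both the
    # first argument and the return value are marked as writes in the schema.
    add_ = torch.ops.aten.add_.Tensor
    print(add_._schema.name[-1] == "_")        # True
    print(add_._schema.arguments[0].is_write)  # True

    # aten::add allocates a fresh output, so it is not classified as in-place.
    add = torch.ops.aten.add.Tensor
    print(add._schema.name[-1] == "_")         # False
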
@@ -100,7 +117,13 @@ def _check_for_subclass_arg(x: object) -> bool:
     return (
         not isinstance(x, LocalTensor)
         and isinstance(x, Tensor)
-        and type(x) not in (Tensor, torch.nn.Parameter, torch.nn.Buffer)
+        and type(x)
+        not in (
+            Tensor,
+            FakeTensor,
+            torch.nn.Parameter,
+            torch.nn.Buffer,
+        )
     )
@@ -220,7 +243,7 @@ def _zero_sized_like(tensor: torch.Tensor, dim: int) -> torch.Tensor:


 def _for_each_rank_run_func(
-    func: Callable[..., Any],
+    func: OpOverload | Callable[..., Any],
     ranks: frozenset[int],
     args: Sequence[Any],
     kwargs: dict[str, Any],
@@ -256,6 +279,14 @@
             split_dim = 0 if len(rank_flat_args) < 3 else rank_flat_args[2]
             default_value = _zero_sized_like(tensor, split_dim)

-    ret = _combine_rank_results(flat_rank_rets, default_value)
+    if _is_inplace_op(func):
+        alias = False
+        # For the in-place ops return self
+        ret = args[0]
+        if isinstance(func, OpOverload) and torch.Tag.inplace_view in func.tags:
+            # Ensure that wrapper tensor size is synchronized with its local tensors
+            ret._sync_meta()
+    else:
+        ret = _combine_rank_results(flat_rank_rets, default_value)

     if alias:
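
Returning args[0] matters because PyTorch in-place ops return self, and callers may rely on that identity:

    import torch

    t = torch.ones(3)
    out = t.add_(1)
    # In-place ops hand back `self`; the LocalTensor wrapper must preserve
    # this identity or aliasing assumptions downstream break.
    assert out is t
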
@@ -386,6 +417,11 @@ class LocalIntNode:
         r = {self._local_ints[r] != _int_on_rank(other, r) for r in self._local_ints}
         return torch._C._get_constant_bool_symnode(len(r) > 1 or next(iter(r)))

+    def ge(self, other: "int | LocalIntNode | ConstantIntNode") -> bool | SymBool:
+        r = {self._local_ints[r] >= _int_on_rank(other, r) for r in self._local_ints}
+        assert len(r) == 1, (self, other)
+        return torch._C._get_constant_bool_symnode(next(iter(r)))
+
     def gt(self, other: "int | LocalIntNode | ConstantIntNode") -> bool | SymBool:
         r = {self._local_ints[r] > _int_on_rank(other, r) for r in self._local_ints}
         assert len(r) == 1, (self, other)
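
Like the existing gt, the new ge evaluates the comparison on every simulated rank and insists the ranks agree, since a rank-dependent comparison has no single boolean answer. The set-based logic, restated in plain Python:

    # Per-rank values on two simulated ranks, e.g. uneven sharding: {rank: value}.
    local_ints = {0: 4, 1: 6}

    # ge against a constant: collect the per-rank outcomes into a set.
    r = {v >= 4 for v in local_ints.values()}
    assert len(r) == 1      # all ranks agree, so a single bool can be returned
    print(next(iter(r)))    # True

    # 4 >= 5 is False on rank 0 but 6 >= 5 is True on rank 1: ambiguous,
    # so the real implementation trips its assertion.
    r = {v >= 5 for v in local_ints.values()}
    print(len(r))           # 2
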
@@ -400,38 +436,43 @@ class LocalIntNode:
         return ConstantIntNode(num)


-class LocalTensor(torch.Tensor):
-    """
-    LocalTensor is a Tensor subclass that simulates a tensor distributed across multiple SPMD
-    (Single Program, Multiple Data) ranks. Each LocalTensor instance internally holds a mapping from
-    global rank ids to their corresponding local Tensor shards. Operations performed on a LocalTensor
-    are applied independently to each local shard, mimicking distributed computation. Collectives
-    and other distributed operations are handled by mapping them to the local shards as appropriate.
-
-    Note:
-        This class is primarily intended for debugging and simulating distributed tensor computations
-        on a single process.
-
-    """
-
-    # Map from global rank to the local tensor.
-    _local_tensors: dict[int, torch.Tensor]
-    # Precomputed for speed set of keys from the local tensor map.
-    _ranks: frozenset[int]
-    __slots__ = ["_local_tensors", "_ranks"]
-
-    @staticmethod
-    @torch._disable_dynamo
-    def __new__(
-        cls,
-        local_tensors: dict[int, torch.Tensor],
-    ) -> "LocalTensor":
-        if any(t.requires_grad for t in local_tensors.values()):
-            raise AssertionError(
-                "Internal local_tensors require grad, but we will ignore those autograd graph. "
-                "Make a custom autograd function and make sure you detach the inner tensors."
-            )
+_LOCAL_TENSOR_ATTR_PREFIX = "_local_tensor_"
+
+
+def _is_local_tensor_attr(attr: str) -> bool:
+    return attr.startswith(_LOCAL_TENSOR_ATTR_PREFIX)
+
+
+def _to_local_tensor_attr(rank: int) -> str:
+    return f"{_LOCAL_TENSOR_ATTR_PREFIX}{rank}"
+
+
+def _from_local_tensor_attr(attr: str) -> int:
+    if not _is_local_tensor_attr(attr):
+        raise AssertionError(f"Invalid local tensor attr {attr}")
+    return int(attr[len(_LOCAL_TENSOR_ATTR_PREFIX) :])
+
+
+def _all_elements_same(values: list[Any]) -> bool:
+    if not values:
+        return True
+    first_value = values[0]
+    return all(value == first_value for value in values)
+
+
+def _compute_local_tensor_meta(
+    local_tensors: dict[int, torch.Tensor],
+) -> tuple[
+    list[torch.SymInt | int],
+    list[torch.SymInt | int],
+    torch.device,
+    torch.dtype,
+    torch.layout,
+    DispatchKeySet,
+]:
+    """
+    Computes the meta information for a LocalTensor from its local tensors.
+    """
     it = iter(local_tensors.values())
     first_local_tensor = next(it)
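
The new attribute scheme gives each per-rank shard a stable name that __tensor_flatten__ can expose (see further down). The round trip, restated from the helpers above:

    _LOCAL_TENSOR_ATTR_PREFIX = "_local_tensor_"

    def _to_local_tensor_attr(rank: int) -> str:
        return f"{_LOCAL_TENSOR_ATTR_PREFIX}{rank}"

    def _from_local_tensor_attr(attr: str) -> int:
        return int(attr[len(_LOCAL_TENSOR_ATTR_PREFIX):])

    # Rank 3's shard is exposed as "_local_tensor_3" and parses back to 3.
    assert _to_local_tensor_attr(3) == "_local_tensor_3"
    assert _from_local_tensor_attr("_local_tensor_3") == 3
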
@@ -457,12 +498,8 @@ class LocalTensor(torch.Tensor):
     )

     # Compute shape/stride. We allow for non-SPMD'ness here
-    local_shapes: dict[int, dict[int, int]] = defaultdict(
-        dict
-    )  # dim => rank => size
-    local_strides: dict[int, dict[int, int]] = defaultdict(
-        dict
-    )  # dim => rank => size
+    local_shapes: dict[int, dict[int, int]] = defaultdict(dict)  # dim => rank => size
+    local_strides: dict[int, dict[int, int]] = defaultdict(dict)  # dim => rank => size
     for r, local_tensor in local_tensors.items():
         for d, size in enumerate(local_tensor.shape):
             local_shapes[d][r] = size
@@ -470,7 +507,7 @@ class LocalTensor(torch.Tensor):
     shape = [
         (
             first_shape[d]
-            if len(set(local_shapes[d])) == 1
+            if _all_elements_same(list(local_shapes[d].values()))
             else torch.SymInt(LocalIntNode(local_shapes[d]))
         )
         for d in range(len(first_shape))
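
When ranks disagree on a dimension (the non-SPMD case the comment above allows for), that dimension is represented by a SymInt backed by a LocalIntNode rather than a plain int. The selection logic, restated on toy data:

    # dim -> {rank: size}; rank 1 has a longer dim 0 than rank 0.
    local_shapes = {0: {0: 4, 1: 6}, 1: {0: 8, 1: 8}}

    def all_same(values):
        return all(v == values[0] for v in values)

    for d, per_rank in local_shapes.items():
        sizes = list(per_rank.values())
        if all_same(sizes):
            print(d, sizes[0])                          # dim 1 -> plain int 8
        else:
            print(d, "SymInt(LocalIntNode)", per_rank)  # dim 0 -> symbolic
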
@@ -478,11 +515,51 @@ class LocalTensor(torch.Tensor):
     strides = [
         (
             first_stride[d]
-            if len(set(local_strides[d])) == 1
+            if _all_elements_same(list(local_strides[d].values()))
             else torch.SymInt(LocalIntNode(local_strides[d]))
         )
         for d in range(len(first_shape))
     ]
+    return shape, strides, device, dtype, layout, extra_dispatch_keys
+
+
+class LocalTensor(torch.Tensor):
+    """
+    LocalTensor is a Tensor subclass that simulates a tensor distributed across multiple SPMD
+    (Single Program, Multiple Data) ranks. Each LocalTensor instance internally holds a mapping from
+    global rank ids to their corresponding local Tensor shards. Operations performed on a LocalTensor
+    are applied independently to each local shard, mimicking distributed computation. Collectives
+    and other distributed operations are handled by mapping them to the local shards as appropriate.
+
+    Note:
+        This class is primarily intended for debugging and simulating distributed tensor computations
+        on a single process.
+
+    """
+
+    # Map from global rank to the local tensor.
+    _local_tensors: dict[int, torch.Tensor]
+    # Precomputed for speed set of keys from the local tensor map.
+    _ranks: frozenset[int]
+    _size: list[torch.SymInt | int]
+    __slots__ = ["_local_tensors", "_ranks", "_size"]
+
+    @staticmethod
+    @torch._disable_dynamo
+    def __new__(
+        cls,
+        local_tensors: dict[int, torch.Tensor],
+        requires_grad: bool = False,
+    ) -> "LocalTensor":
+        if any(t.requires_grad for t in local_tensors.values()):
+            raise AssertionError(
+                "Internal local_tensors require grad, but we will ignore those autograd graph. "
+                "Make a custom autograd function and make sure you detach the inner tensors."
+            )
+
+        (shape, strides, device, dtype, layout, extra_dispatch_keys) = (
+            _compute_local_tensor_meta(local_tensors)
+        )
+
         r = torch.Tensor._make_wrapper_subclass(
             cls,
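
Construction now routes all metadata through _compute_local_tensor_meta, and requires_grad is passed straight to _make_wrapper_subclass instead of being patched on afterwards. Typical usage is unchanged, e.g. (a sketch; the inner shards must be detached):

    import torch
    from torch.distributed._local_tensor import LocalTensor

    # One detached shard per simulated rank; if shapes differed per rank,
    # the wrapper's size would become symbolic instead.
    lt = LocalTensor({0: torch.ones(2, 2), 1: torch.zeros(2, 2)})
    print(lt.shape)  # torch.Size([2, 2]); the ranks agree here
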
@@ -491,7 +568,13 @@
             dtype=dtype,
             device=device,
             layout=layout,
-            requires_grad=False,
+            # In place ops potentially change local tensor sizes (e.g. resize_). While
+            # executing an in-place op the return value must be the same as "self" input
+            # otherwise we can introduce errors due to tensor identity changes. Hence we
+            # need to be able to update wrapper subclass sizes after in-place ops. This
+            # dispatch policy allows us to do that.
+            dispatch_sizes_strides_policy="sizes",
+            requires_grad=requires_grad,
             _extra_dispatch_keys=extra_dispatch_keys,
         )
@@ -501,6 +584,7 @@
         }
         r._local_tensors = local_tensors
         r._ranks = frozenset(local_tensors.keys())
+        r._size = shape
         return r

     @torch._disable_dynamo
@@ -512,9 +596,7 @@
         local_tensors_copy = {
             r: copy.deepcopy(t, memo) for r, t in self._local_tensors.items()
         }
-        tensor_copy = LocalTensor(local_tensors_copy)
-        tensor_copy.requires_grad = self.requires_grad
-        return tensor_copy
+        return LocalTensor(local_tensors_copy, self.requires_grad)

     def __repr__(self) -> str:  # type: ignore[override]
         parts = []
@@ -524,12 +606,21 @@
         tensors_str = ",\n".join(parts)
         return f"LocalTensor(\n{tensors_str}\n)"

+    def __getattr__(self, name: str) -> Any:
+        if _is_local_tensor_attr(name):
+            rank = _from_local_tensor_attr(name)
+            if rank not in self._ranks:
+                raise AttributeError(f"Local tensor has no knowledge of rank {rank}")
+            return self._local_tensors[rank]
+        return object.__getattribute__(self, name)
+
     def __tensor_flatten__(self) -> tuple[list[str], tuple[Any, ...]]:
         """
         protocol to inform how to flatten a DTensor to local tensor
         for PT2 tracing
         """
-        return ["_local_tensors"], ()
+        local_tensor_attrs = [_to_local_tensor_attr(r) for r in self._ranks]
+        return local_tensor_attrs, ()

     @staticmethod
     def __tensor_unflatten__(
@@ -541,8 +632,9 @@
         assert flatten_spec is not None, (
             "Expecting spec to be not None from `__tensor_flatten__` return value!"
         )
-        local_tensors = inner_tensors["_local_tensors"]
+        # pyrefly: ignore [bad-argument-type, bad-argument-count]
+        local_tensors = {
+            _from_local_tensor_attr(a): t for a, t in inner_tensors.items()
+        }
         return LocalTensor(local_tensors)

     @classmethod
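
With per-rank attribute names, __tensor_flatten__ / __tensor_unflatten__ round-trip each shard as a separately named inner tensor, which PT2 tracing needs. Schematically (strings stand in for the shard tensors):

    # flatten:   LocalTensor{0: t0, 1: t1} -> ["_local_tensor_0", "_local_tensor_1"], ()
    # unflatten: {"_local_tensor_0": t0, ...} -> LocalTensor({0: t0, 1: t1})
    inner_tensors = {"_local_tensor_0": "t0", "_local_tensor_1": "t1"}
    local_tensors = {
        int(a[len("_local_tensor_"):]): t for a, t in inner_tensors.items()
    }
    print(local_tensors)  # {0: 't0', 1: 't1'}
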
@@ -591,24 +683,6 @@
         else:
             raise RuntimeError("Numpy is not available")

-    def __lt__(
-        self, other: torch.Tensor | bool | complex | float | int
-    ) -> torch.Tensor:
-        self_rec = self.reconcile()
-        other_rec = other
-        if isinstance(other, LocalTensor):
-            other_rec = other.reconcile()
-        return self_rec < other_rec
-
-    def __gt__(
-        self, other: torch.Tensor | bool | complex | float | int
-    ) -> torch.Tensor:
-        self_rec = self.reconcile()
-        other_rec = other
-        if isinstance(other, LocalTensor):
-            other_rec = other.reconcile()
-        return self_rec > other_rec
-
     def contiguous(
         self,
         memory_format: torch.memory_format = torch.contiguous_format,
@@ -660,6 +734,13 @@
         cl.requires_grad_(self.requires_grad)
         return cl

+    def _sync_meta(self) -> None:
+        with no_dispatch():
+            (shape, strides, device, dtype, layout, extra_dispatch_keys) = (
+                _compute_local_tensor_meta(self._local_tensors)
+            )
+            self._size = shape
+

 _LOCAL_TENSOR_MODE: list["LocalTensorMode"] = []
@@ -753,6 +834,11 @@ class LocalTensorMode(TorchDispatchMode):
                     f"Input LocalTensor {a} and LocalTensorMode must be configured for the same ranks"
                 )

+        if func.overloadpacket == torch.ops.aten.dim:
+            return len(args[0]._size)
+        if func.overloadpacket == torch.ops.aten.sym_size:
+            return tuple(args[0]._size)
+
         if func.namespace == "c10d":
             if func is torch.ops.c10d.allreduce_.default:
                 return _c10d._local_all_reduce_(*args, **kwargs)
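
This pairs with dispatch_sizes_strides_policy="sizes" above: size queries now reach the dispatcher, and the mode answers aten.dim / aten.sym_size from the cached _size that _sync_meta refreshes after in-place ops. The branch, restated on a stub object:

    # Illustrative only: the mode returns metadata from the wrapper's cached
    # _size instead of dispatching to each shard.
    class FakeLocalTensor:
        _size = [2, 4]  # kept current by _sync_meta after in-place ops

    def handle(func_name, args):
        if func_name == "dim":
            return len(args[0]._size)
        if func_name == "sym_size":
            return tuple(args[0]._size)

    print(handle("dim", [FakeLocalTensor()]))       # 2
    print(handle("sym_size", [FakeLocalTensor()]))  # (2, 4)
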
@@ -17,7 +17,6 @@ import torch.distributed as dist
 import torch.nn as nn
 import torch.nn.functional as F
 from torch.distributed._local_tensor import (
-    local_tensor_mode,
     LocalIntNode,
     LocalTensor,
     LocalTensorMode,
@@ -715,9 +714,6 @@ class LocalDTensorTestBase(DTensorTestBase):
             self.skipTest(msg)

     def _get_local_tensor_mode(self):
-        lm = local_tensor_mode()
-        if lm is not None:
-            breakpoint()
         return LocalTensorMode(frozenset(range(self.world_size)))

     def setUp(self) -> None: