pytorch/test/distributed/tensor/test_random_ops.py
Tianyu Liu 92f7361e27 [DTensor] fix uneven _StridedShard (#163843)
The previous uneven `_StridedShard` from https://github.com/pytorch/pytorch/pull/150490 appears to fail on cases such as sharding `tensor = torch.arange(6)` with FSDP 2, TP 2.
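
For illustration, here is a minimal sketch (not part of this PR) of the layout involved, assuming 4 ranks and a `("dp", "tp") = (2, 2)` mesh; the `_StridedShard(0, split_factor=2)` placement and passing it directly to `distribute_tensor` are assumptions about the private API rather than something this PR prescribes:

```python
# Hypothetical sketch: TP shards the 6 elements into 3 + 3, then FSDP re-shards
# each piece into 2 + 1 -- the uneven strided layout this PR targets.
# Assumes torch.distributed is already initialized with 4 ranks (e.g. via torchrun).
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import distribute_tensor, Shard
from torch.distributed.tensor.placement_types import _StridedShard

mesh = init_device_mesh("cuda", (2, 2), mesh_dim_names=("dp", "tp"))
tensor = torch.arange(6)
# the dp dim carries the strided shard produced by composing FSDP over a TP-sharded dim 0
dtensor = distribute_tensor(tensor, mesh, [_StridedShard(0, split_factor=2), Shard(0)])
print(dtensor.to_local())     # uneven local shard on each rank
print(dtensor.full_tensor())  # should reassemble to torch.arange(6)
```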

This PR attempts to reinvent `_StridedShard`.

I didn't test nested `_StridedShard` because there shouldn't be any use cases for it, and handling **nested uneven** `_StridedShard` would get quite messy. We will probably deprecate it anyway after @zpcore's work on ordered sharding in https://github.com/pytorch/pytorch/pull/160266, so IMO it is not worth making it too general.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163843
Approved by: https://github.com/ezyang
2025-09-25 22:12:29 +00:00

# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]
import itertools
import torch
import torch.distributed._functional_collectives as funcol
import torch.distributed.tensor._random as random
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.distributed_c10d import broadcast_object_list
from torch.distributed.fsdp import fully_shard
from torch.distributed.tensor import (
DeviceMesh,
distribute_tensor,
DTensor,
Replicate,
Shard,
)
from torch.distributed.tensor._random import (
is_rng_supported_mesh,
manual_seed,
OffsetBasedRNGTracker,
)
from torch.distributed.tensor._utils import compute_local_shape_and_global_offset
from torch.distributed.tensor.debug import CommDebugMode
from torch.distributed.tensor.parallel import ColwiseParallel, parallelize_module
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
DTensorTestBase,
skip_if_lt_x_gpu,
skip_unless_torch_gpu,
with_comms,
)
from torch.utils._typing_utils import not_none


def get_generator_seed_for_device_type(device_type: str) -> int:
device_module = torch.get_device_module(device_type)
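    # assumes the device RNG state is laid out as [seed (int64), philox offset],
    # as on CUDA, so the first 8 bytes hold the current seed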
    return device_module.get_rng_state()[:8].view(torch.int64).item()


class DistTensorRandomInitTest(DTensorTestBase):
def _run_init_op(self, init_op, *args, **kwargs):
device_mesh = self.build_device_mesh()
shard_spec = [Shard(0)]
input_size = (8, 4)
        # NOTE: random initialization on GPU devices currently behaves differently
        # from other devices. Unify the test once the behavior is unified.
if not is_rng_supported_mesh(device_mesh):
input_tensor = torch.randn(*input_size, device=self.device_type)
dtensor = DTensor.from_local(input_tensor, device_mesh, shard_spec)
local_tensor_clone = torch.clone(input_tensor)
torch.manual_seed(self.rank)
local_tensor_clone = init_op(local_tensor_clone, *args, **kwargs)
torch.manual_seed(self.rank)
dtensor = init_op(dtensor, *args, **kwargs)
self.assertEqual(local_tensor_clone, dtensor.to_local())
else:
# create DTensor from Tensor
_tensor = torch.empty(*input_size, device=self.device_type)
dtensor = distribute_tensor(_tensor, device_mesh, [Shard(1)])
# DTensor random init
dtensor = init_op(dtensor, *args, **kwargs)
local_tensor = dtensor.to_local()
# compare with local tensors from other ranks
for other_rank in range(self.world_size):
if self.rank != other_rank:
slice_idx = (
slice(input_size[0]),
slice(
other_rank * input_size[1], (other_rank + 1) * input_size[1]
),
)
# other rank should have a different local tensor
self.assertNotEqual(dtensor.full_tensor()[slice_idx], local_tensor)
@with_comms
def test_init_ops(self):
self._run_init_op(
torch.nn.init.kaiming_uniform_,
a=0,
mode="fan_in",
nonlinearity="leaky_relu",
)
self._run_init_op(torch.nn.init.normal_, mean=1.5, std=0.8)
self._run_init_op(torch.nn.init.uniform_, a=0, b=1.2)
for dtype in (torch.float32, torch.float16):
self._run_init_op(torch.rand_like, dtype=dtype)
self._run_init_op(torch.randn_like, dtype=dtype)
self._run_init_op(torch.randint_like, low=0, high=100, dtype=dtype)
@with_comms
@skip_if_lt_x_gpu(4)
def test_init_with_user_generator(self):
device_mesh = self.build_device_mesh()
torch.manual_seed(42)
rng = torch.Generator(device=self.device_type).manual_seed(42)
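        # `rng` is seeded identically to the default generator, so initializing t2
        # through it should match initializing t1 through the default generator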
t1 = torch.distributed.tensor.empty(
(8, 3), device_mesh=device_mesh, placements=[Shard(0)]
)
t2 = torch.distributed.tensor.empty(
(8, 3), device_mesh=device_mesh, placements=[Shard(0)]
)
for i in range(2):
            # loop twice to make sure that `rng`'s offset state advances on the second usage
torch.nn.init.uniform_(t1, 0.0, 1.0)
torch.nn.init.uniform_(t2, 0.0, 1.0, rng)
self.assertEqual(t1.full_tensor(), t2.full_tensor(), f"Failed at {i=}")
        # ensure that we do not cache the seed from the first time DTensor sees it.
        # This is a behavior change: DTensor used to cache the generator state and leave
        # the original generator untouched; now it advances the original generator instead.
torch.manual_seed(55)
rng.manual_seed(55)
torch.nn.init.uniform_(t1, 0.0, 1.0)
torch.nn.init.uniform_(t2, 0.0, 1.0, rng)
self.assertEqual(t1.full_tensor(), t2.full_tensor())
@with_comms
@skip_if_lt_x_gpu(4)
def test_meta_tensor_init(self):
        # The test suite sets each rank's seed to the same value. The DTensor random
        # ops will use the same generator as the default one on the device.
        # Note: this behavior changed, and the guideline now is to set the same RNG
        # seed on all SPMD ranks.
torch.get_device_module(self.device_type).manual_seed(0)
device_mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
size = [1024, 2048]
meta_dtensor = distribute_tensor(
torch.empty(*size, device="meta"), device_mesh, [Replicate()]
)
# the tensor slice on the current rank
self_slice = slice(1024 * self.rank, 1024 * self.rank + 1024)
# Test 1: enable the distribute region for RNG (by default)
self.assertTrue(meta_dtensor.is_meta)
# Tensor meta init
dtensor = torch.empty_like(meta_dtensor, device=self.device_type)
dtensor.uniform_()
# check `distribute_region_enabled` is set to True by default
self.assertTrue(random._rng_tracker.distribute_region_enabled)
# allgather the local tensors
gathered_local_tensors = funcol.all_gather_tensor(
dtensor.to_local(), gather_dim=0, group=(device_mesh, 0)
)
# compare with local tensors from other ranks
for other_rank in range(self.world_size):
            # the RNG results on each rank are the same because the tensor is replicated
if self.rank != other_rank:
# other rank should have an identical local tensor
other_slice = slice(1024 * other_rank, 1024 * other_rank + 1024)
self.assertEqual(
gathered_local_tensors[self_slice, :],
gathered_local_tensors[other_slice, :],
)
# Test 2: disable the distribute region for RNG
self.assertTrue(meta_dtensor.is_meta)
# Tensor meta init
dtensor = torch.empty_like(meta_dtensor, device=self.device_type)
random._rng_tracker.distribute_region_enabled = False
dtensor.uniform_()
# check `distribute_region_enabled` is set to False
self.assertTrue(not random._rng_tracker.distribute_region_enabled)
# allgather the local tensors
local_tensor = funcol.all_gather_tensor(
dtensor.to_local(), gather_dim=0, group=(device_mesh, 0)
)
# compare with local tensors from other ranks
for other_rank in range(self.world_size):
            # the RNG results on each rank are the same even without the help of DTensor's
            # RNG infra, since the default RNG is the same across ranks.
if self.rank != other_rank:
other_slice = slice(1024 * other_rank, 1024 * other_rank + 1024)
self.assertEqual(
local_tensor[self_slice, :], local_tensor[other_slice, :]
)
@with_comms
@skip_unless_torch_gpu
def test_tp_model_meta_init(self):
# initialize the 1-d device mesh for TP
tp_mesh = init_device_mesh(self.device_type, mesh_shape=(self.world_size,))
# model meta init
with torch.device("meta"):
model = torch.nn.Linear(self.world_size, self.world_size, bias=False)
self.assertEqual(model.weight.device, torch.device("meta"))
parallelize_module(model, tp_mesh, ColwiseParallel())
if random._rng_tracker is not None:
random._rng_tracker.distribute_region_enabled = True
self.assertEqual(model.weight.device, torch.device("meta"))
# actual initialization
device = torch.device(
self.device_type, torch.get_device_module(self.device_type).current_device()
)
model.to_empty(device=device)
model.reset_parameters()
self.assertTrue(
random._rng_tracker is not None
and isinstance(random._rng_tracker, OffsetBasedRNGTracker)
)
self.assertEqual(model.weight.device, device)
assert isinstance(model.weight, DTensor)
# gather all the shards to compare initialization results
WORLD = torch.distributed.group.WORLD
assert WORLD is not None
weight_local = model.weight.to_local()
weight_gather = funcol.all_gather_tensor(
weight_local,
gather_dim=0,
group=WORLD,
)
# verify the weights are initialized differently on all ranks
for other_rank in range(self.world_size):
if self.rank != other_rank:
self.assertNotEqual(
weight_local,
weight_gather[other_rank : other_rank + 1, :],
)
@with_comms
@skip_if_lt_x_gpu(4)
def test_fsdp_tp_model_meta_init(self):
# initialize the 2-d device mesh
global_mesh = init_device_mesh(
self.device_type,
mesh_shape=(self.world_size // 2, 2),
mesh_dim_names=("dp", "tp"),
)
dp_mesh, tp_mesh = global_mesh["dp"], global_mesh["tp"]
# model meta init
with torch.device("meta"):
model = torch.nn.Linear(self.world_size, self.world_size, bias=False)
self.assertEqual(model.weight.device, torch.device("meta"))
parallelize_module(model, tp_mesh, ColwiseParallel())
if random._rng_tracker is not None:
random._rng_tracker.distribute_region_enabled = True
fully_shard(model, mesh=dp_mesh)
self.assertEqual(model.weight.device, torch.device("meta"))
# actual initialization
device = torch.device(
self.device_type, torch.get_device_module(self.device_type).current_device()
)
model.to_empty(device=device)
model.reset_parameters()
self.assertTrue(
random._rng_tracker is not None
and isinstance(random._rng_tracker, OffsetBasedRNGTracker)
)
self.assertEqual(model.weight.device, device)
assert isinstance(model.weight, DTensor)
# gather all the shards to compare initialization results
WORLD = torch.distributed.group.WORLD
assert WORLD is not None
weight_local = model.weight.to_local()
weight_gather = funcol.all_gather_tensor(
weight_local,
gather_dim=0,
group=WORLD,
)
# verify the weights are initialized differently on all ranks
for other_rank in range(self.world_size):
if self.rank != other_rank:
self.assertNotEqual(
weight_local,
weight_gather[other_rank : other_rank + 1, :],
                )


class DistTensorRandomOpTest(DTensorTestBase):
@with_comms
@skip_unless_torch_gpu
def test_rng_tracker_init(self):
torch.manual_seed(self.rank)
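        # record rank 0's seed so each rank can later verify that the first DTensor
        # random op leaves its device generator seeded with rank 0's value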
object_list = [torch.initial_seed()]
broadcast_object_list(object_list)
seed_from_rank_0 = int(object_list[0])
device_mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
# seed synchronization now does NOT happen after the first `distribute_tensor`
# call
dt = distribute_tensor(
torch.empty([self.world_size], device=self.device_type),
device_mesh,
[Shard(0)],
)
self.assertTrue(random._rng_tracker is None)
# seed synchronization only happens after `manual_seed` or the first DTensor
# random op call
dt.uniform_(0, 1)
        # We do not maintain a copy of the seed in DTensor, but we do mutate the global
        # RNG state, since we now always pull it fresh from the local device generator.
self.assertEqual(
seed_from_rank_0, get_generator_seed_for_device_type(self.device_type)
)
@with_comms
@skip_unless_torch_gpu
def test_manual_seed(self):
device_mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
        # when calling ``torch.distributed.tensor._random.manual_seed``, no seed
        # synchronization should happen since we fully trust the user's input and
        # will not override the value.
comm_mode = CommDebugMode()
with comm_mode:
# Test 1: set different seed on different ranks
# RNG tracker should not be initialized until DTensor ``manual_seed``
# is called.
self.assertTrue(random._rng_tracker is None)
manual_seed(self.rank, device_mesh)
# RNG tracker should already be initialized
self.assertTrue(random._rng_tracker is not None)
self.assertEqual(
self.rank, get_generator_seed_for_device_type(self.device_type)
)
# Test 2: set same seed on different ranks
manual_seed(1234, device_mesh)
self.assertEqual(1234, get_generator_seed_for_device_type(self.device_type))
self.assertEqual(comm_mode.get_total_counts(), 0)
@with_comms
@skip_unless_torch_gpu
def test_manual_seed_submesh(self):
# the current rank is not a part of the mesh
single_rank_device_mesh = DeviceMesh(
self.device_type, [(self.rank + 1) % self.world_size]
)
with self.assertRaisesRegex(
RuntimeError,
"manual_seed requires the current rank to be a part of the device mesh",
):
manual_seed(self.rank, single_rank_device_mesh)
@with_comms
@skip_unless_torch_gpu
def test_pipeline_parallel_manual_seed(self):
        # This test verifies that the `manual_seed` API works as expected in the
        # pipeline parallel setting.
world_mesh = init_device_mesh(
self.device_type,
(self.world_size // 2, 2),
mesh_dim_names=("pp", "spmd"),
)
pp_mesh = world_mesh["pp"]
pp_rank = pp_mesh.get_local_rank() # rank 0,1 = 0; rank 2,3 = 1
spmd_mesh = world_mesh["spmd"]
# set the seed for each pipeline stage to 123 + pp_rank
manual_seed(123 + pp_rank, spmd_mesh)
        # DTensor no longer stores a copy of the seed, but it mutates the device's
        # generator, so we can check the generator instead.
self.assertEqual(
123 + pp_rank, get_generator_seed_for_device_type(self.device_type)
)
# mimic initializing a model weight sharded on the SPMD mesh
spmd_dtensor = torch.distributed.tensor.ones(
2 * spmd_mesh.size(), 2, device_mesh=spmd_mesh, placements=[Shard(0)]
)
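        # each SPMD rank holds a (2, 2) local shard of the (2 * spmd_mesh.size(), 2) tensor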
torch.nn.init.normal_(spmd_dtensor)
# gather all the shards to compare initialization results
WORLD = torch.distributed.group.WORLD
assert WORLD is not None
tensor_gather = funcol.all_gather_tensor(
spmd_dtensor.to_local(),
gather_dim=0,
group=WORLD,
)
# verify the weights are initialized differently on all ranks
for other_rank in range(self.world_size):
if self.rank != other_rank:
self.assertNotEqual(
spmd_dtensor.to_local(),
tensor_gather[2 * other_rank : 2 * (other_rank + 1), :],
)
@with_comms
@skip_unless_torch_gpu
def test_deterministic_dropout_1d(self):
        # The test suite sets each rank's seed to the same value, but in actual
        # execution the default random seed will be different (a random value).
        # The DTensor random ops will use the same random seed even though the
        # torch random generator keeps different seeds across ranks.
torch.manual_seed(self.rank)
# TODO: add test before/after enabling distribute region
device_mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
size = [4, 4]
dtensor = distribute_tensor(
torch.empty(*size, device=self.device_type), device_mesh, [Shard(1)]
)
# a random op call shifts the offset
dtensor.uniform_(0, 1)
        # make the dtensor replicated on all ranks
dtensor = dtensor.redistribute(device_mesh, [Replicate()])
dropout = torch.nn.Dropout(p=0.2)
dtensor = dropout(dtensor)
# allgather the local tensors
local_tensor = funcol.all_gather_tensor(
dtensor.to_local(), gather_dim=0, group=(device_mesh, 0)
)
# compare with local tensors from other ranks
self_slice = slice(4 * self.rank, 4 * self.rank + 4)
for other_rank in range(self.world_size):
if self.rank != other_rank:
# other rank should have an identical local tensor
other_slice = slice(4 * other_rank, 4 * other_rank + 4)
self.assertEqual(
local_tensor[self_slice, :],
local_tensor[other_slice, :],
)
@with_comms
@skip_unless_torch_gpu
def test_deterministic_rand_1d(self):
device_mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
size = [4, 4 * self.world_size]
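        # with Shard(1), each rank holds a (4, 4) local shard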
for fn in [
torch.distributed.tensor.rand,
torch.distributed.tensor.randn,
]:
dtensor = fn(size, device_mesh=device_mesh, placements=[Shard(1)])
local_tensor = funcol.all_gather_tensor(
dtensor.to_local(), gather_dim=0, group=(device_mesh, 0)
)
# compare with local tensors from other ranks
self_slice = slice(4 * self.rank, 4 * self.rank + 4)
for other_rank in range(self.world_size):
if self.rank != other_rank:
# other rank should have a different local tensor for shard placement
other_slice = slice(4 * other_rank, 4 * other_rank + 4)
self.assertNotEqual(
local_tensor[self_slice, :],
local_tensor[other_slice, :],
)
# we should set manual seed to the same value on all SPMD ranks
torch.manual_seed(0)
dtensor = fn(size, device_mesh=device_mesh, placements=[Replicate()])
local_tensor = funcol.all_gather_tensor(
dtensor.to_local(), gather_dim=0, group=(device_mesh, 0)
)
# compare with local tensors from other ranks
self_slice = slice(4 * self.rank, 4 * self.rank + 4)
for other_rank in range(self.world_size):
if self.rank != other_rank:
# other rank should have an identical local tensor for replicate placement
other_slice = slice(4 * other_rank, 4 * other_rank + 4)
self.assertEqual(
local_tensor[self_slice, :],
local_tensor[other_slice, :],
)
@with_comms
@skip_if_lt_x_gpu(4)
def test_deterministic_uniform_2d(self):
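        # lay out the ranks as a 2x2 mesh, i.e. [[0, 1], [2, 3]] with a world size of 4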
mesh = torch.arange(self.world_size).reshape(2, 2)
device_mesh = DeviceMesh(self.device_type, mesh)
dtensor = distribute_tensor(
torch.empty(
*[self.world_size for _ in mesh.size()], device=self.device_type
),
device_mesh,
[Replicate(), Replicate()],
)
        placements_list = [  # this list of placements should provide enough coverage
[Shard(0), Shard(1)],
[Shard(1), Shard(0)],
[Shard(0), Replicate()],
[Replicate(), Shard(0)],
[Shard(1), Replicate()],
[Replicate(), Shard(1)],
[Replicate(), Replicate()],
]
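        # for each placements entry above, maps rank -> expected shard linear index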
shard_index_list = [
{0: 0, 1: 1, 2: 2, 3: 3},
{0: 0, 1: 2, 2: 1, 3: 3},
{0: 0, 1: 0, 2: 1, 3: 1},
{0: 0, 1: 1, 2: 0, 3: 1},
{0: 0, 1: 0, 2: 1, 3: 1},
{0: 0, 1: 1, 2: 0, 3: 1},
{0: 0, 1: 0, 2: 0, 3: 0},
]
coordinate = device_mesh.get_coordinate()
assert coordinate is not None
for placements, shard_index in zip(placements_list, shard_index_list):
dtensor = dtensor.redistribute(device_mesh, placements)
# random op call
dtensor.uniform_(0, 1)
# check shard information is correct
shard_coord = [
coordinate[mesh_dim] if mesh_dim >= 0 else 0
for mesh_dim in dtensor._spec.dim_map
]
shard_size = [
device_mesh.size(mesh_dim) if mesh_dim >= 0 else 1
for mesh_dim in dtensor._spec.dim_map
]
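            # linearize this rank's coordinate over the sharded mesh dims; the RNG
            # tracker uses this index to give each shard a distinct RNG offset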
shard_linear_idx = random._rng_tracker._calc_shard_linear_idx(
shard_coord, shard_size
)
self.assertEqual(shard_linear_idx, shard_index[self.rank])
# compute local size and offset
_, local_shard_offset = compute_local_shape_and_global_offset(
dtensor.shape, device_mesh, placements
)
# get the local shard size and local shard offset for each shard
# local_shard_list_on_dim[i] has the list of all shards on that dim
# as a tuple (local_shard_offset, local_shard_size)
dtensor_shape = dtensor.shape
local_shard_list_on_dim: list[list[tuple[int, int]]] = [
[(0, l)] for l in dtensor_shape
]
for idx, placement in enumerate(placements):
if isinstance(placement, Shard):
mesh_dim_size = device_mesh.size(idx)
shard_dim = placement.dim
local_shard_list_on_dim[shard_dim] = []
for shard_idx_on_dim in range(mesh_dim_size):
(
shard_size,
shard_offset,
) = placement._local_shard_size_and_offset(
dtensor_shape[shard_dim],
mesh_dim_size,
shard_idx_on_dim,
)
local_shard_list_on_dim[shard_dim].append(
(not_none(shard_offset), shard_size)
)
local_shard_comb = itertools.product(*local_shard_list_on_dim)
# the local shard
local_tensor = dtensor.to_local()
# allgather the local tensors
full_tensor = dtensor.full_tensor()
# compare local tensor with each other shard
for other_local_shard in local_shard_comb:
other_local_shard_offset, _ = zip(*other_local_shard)
slice_idx = [
slice(offset, offset + size) for offset, size in other_local_shard
]
if local_shard_offset == other_local_shard_offset:
self.assertEqual(full_tensor[tuple(slice_idx)], local_tensor)
else:
                    self.assertNotEqual(full_tensor[tuple(slice_idx)], local_tensor)


class DistTensorRandomOpsTest3D(DTensorTestBase):
@property
def world_size(self):
return 8
@skip_if_lt_x_gpu(8)
@with_comms
def test_hsdp_tp_model_meta_init(self):
# initialize the 3-d device mesh
global_mesh = init_device_mesh(
self.device_type,
mesh_shape=(self.world_size // 4, 2, 2),
mesh_dim_names=("dp_replicate", "dp_shard", "tp"),
)
tp_mesh = global_mesh["tp"]
dp_mesh = global_mesh["dp_replicate", "dp_shard"]
# model meta init
with torch.device("meta"):
model = torch.nn.Linear(self.world_size, self.world_size, bias=False)
self.assertEqual(model.weight.device, torch.device("meta"))
parallelize_module(model, tp_mesh, ColwiseParallel())
if random._rng_tracker is not None:
random._rng_tracker.distribute_region_enabled = True
fully_shard(model, mesh=dp_mesh)
self.assertEqual(model.weight.device, torch.device("meta"))
# actual initialization
device = torch.device(
self.device_type, torch.get_device_module(self.device_type).current_device()
)
model.to_empty(device=device)
model.reset_parameters()
self.assertTrue(
random._rng_tracker is not None
and isinstance(random._rng_tracker, OffsetBasedRNGTracker)
)
self.assertEqual(model.weight.device, device)
assert isinstance(model.weight, DTensor)
# gather all the shards to compare initialization results
WORLD = torch.distributed.group.WORLD
assert WORLD is not None
weight_local = model.weight.to_local()
weight_gather = funcol.all_gather_tensor(
weight_local,
gather_dim=0,
group=WORLD,
)
        # verify the weights are initialized differently across ranks, except for ranks
        # that only differ along the replicated dp_replicate mesh dim
shard_dim_0_len = self.world_size // 4
for other_rank in range(self.world_size):
other_rank_dim_0_start = other_rank * shard_dim_0_len
other_rank_dim_0_end = other_rank_dim_0_start + shard_dim_0_len
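            # rank % 4 identifies the (dp_shard, tp) coordinate; ranks sharing it differ
            # only along the replicated dp_replicate dim and should hold identical shards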
if self.rank % 4 != other_rank % 4:
self.assertNotEqual(
weight_local,
weight_gather[other_rank_dim_0_start:other_rank_dim_0_end, :],
)
else:
self.assertEqual(
weight_local,
weight_gather[other_rank_dim_0_start:other_rank_dim_0_end, :],
                )


if __name__ == "__main__":
run_tests()