[DTensor] Make default RNG semantics match user-passed generator (#160482)

Previously, DTensor kept its own copy of the generator state after the
first time a random operator was called on a DTensor. This copy would
evolve independently from the generator outside of DTensor.

After adding support for users to pass a specific generator into
random operators (e.g. `uniform_(..., generator=)`), it was decided
(in discussion on #159991) to change the semantics so that any random
operations performed on DTensor evolve the state of the publicly
visible generators (either the default one or a user-passed one).

The upsides are (1) it is now possible to call `torch.manual_seed()` at
any point in the program and have a consistent effect on DTensor, and (2)
DTensor ops have an observable effect on the generator. The downside is
that users are now responsible for seeding their generator before using
DTensor and for ensuring all ranks use the same seed.
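
A rough sketch of the new behavior (illustrative only, not part of the change itself; it assumes `mesh` is an already-initialized `DeviceMesh` covering all ranks and that every rank runs the same code):

```python
import torch
from torch.distributed.tensor import distribute_tensor, Replicate

# Seed the publicly visible generator identically on every rank up front.
torch.manual_seed(0)
before = torch.cuda.get_rng_state().clone()

dt = distribute_tensor(torch.empty(8, 8, device="cuda"), mesh, [Replicate()])
dt.uniform_(0, 1)

# The DTensor op advanced the default CUDA generator; there is no hidden,
# DTensor-private copy of the state anymore.
assert not torch.equal(before, torch.cuda.get_rng_state())
```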

Fixes #159991

Confirmed the docs rendered OK.


Pull Request resolved: https://github.com/pytorch/pytorch/pull/160482
Approved by: https://github.com/wanchaol
Will Constable 2025-08-20 17:25:22 -07:00 committed by PyTorch MergeBot
parent cc2b65a91a
commit d1faf2ef04
4 changed files with 104 additions and 49 deletions


@ -179,6 +179,18 @@ specifying the {class}`DeviceMesh` and {class}`Placement` for the {class}`DTenso
```
### Random Operations
DTensor provides distributed RNG functionality to ensure that random operations on sharded tensors get unique values per shard, and random operations on replicated tensors get the same values on every rank. This system requires that all participating
ranks (e.g. SPMD ranks) start out with the same generator state before each DTensor random operation is performed;
if that holds, it ensures they all end up in the same state after each DTensor random operation completes. No communication is performed during random operations to synchronize RNG states.
Operators that accept a `generator` kwarg use the user-passed generator if one is provided, or the default generator for the device otherwise. Whichever generator is used, its state is advanced by the DTensor operation. It is valid to use the same generator for both DTensor and non-DTensor operations, but in that case care must be taken to ensure the non-DTensor operations advance the generator state equally on all ranks.
When using DTensor together with Pipeline Parallelism, ranks in different pipeline stages should use distinct seeds, and ranks within a pipeline stage should use the same seed.
DTensor's RNG infrastructure is based on the Philox RNG algorithm and supports any Philox-based backend (CUDA and other CUDA-like devices), but it does not yet support the CPU backend.
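
Below is a minimal usage sketch (assuming a multi-GPU job launched with `torchrun`, which sets `LOCAL_RANK`; the tensor sizes and mesh shape are illustrative):

```python
import os

import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import distribute_tensor, Replicate, Shard

dist.init_process_group("nccl")
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
mesh = init_device_mesh("cuda", (dist.get_world_size(),))

# Seed the default generator identically on every SPMD rank *before* any
# DTensor random op; DTensor advances this same, publicly visible generator.
torch.manual_seed(0)

global_tensor = torch.empty(1024, 1024, device="cuda")
sharded = distribute_tensor(global_tensor, mesh, [Shard(0)])
replicated = distribute_tensor(global_tensor, mesh, [Replicate()])

sharded.uniform_(0, 1)     # each rank's local shard gets distinct values
replicated.uniform_(0, 1)  # every rank gets identical values

# A user-passed generator behaves the same way: it must be seeded identically
# on all ranks, and its state is advanced by the DTensor operation.
g = torch.Generator(device="cuda").manual_seed(0)
sharded.uniform_(0, 1, generator=g)
```
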
## Debugging
```{eval-rst}


@ -66,7 +66,7 @@ class TensorParallelRandomStateTests(DTensorTestBase):
# in the following way:
# - within a tensor parallel group, the RNG is set with the same seed
# - across data parallel groups, the RNG is set with different seeds
torch.cuda.manual_seed(dp_rank)
torch.cuda.manual_seed(0)
# disable/enable parallel RNG feature
if random._rng_tracker:
@ -118,14 +118,10 @@ class TensorParallelRandomStateTests(DTensorTestBase):
# compare local shards across TP groups
def dp_weights_assert(tensor1, tensor2):
if enable_distribute_flag:
# local weights shall be initialized the same across TP groups
self.assertEqual(tensor1, tensor2)
else:
# without the parallel RNG, weight initialization violates the TP setup:
# local weights are initialized differently across TP groups due to different
# random seeds set in data loading.
self.assertNotEqual(tensor1, tensor2)
# local weights shall be initialized the same across TP groups,
# and it doesn't matter whether DTensor's RNG infra is activated, since all SPMD ranks
# started with the same seed.
self.assertEqual(tensor1, tensor2)
self.check_gathered_tensors(
dp_rank, dp_size, tensor_gather, dp_weights_assert


@ -33,6 +33,11 @@ from torch.testing._internal.distributed._tensor.common_dtensor import (
)
def get_generator_seed_for_device_type(device_type: str) -> int:
device_module = torch.get_device_module(device_type)
return device_module.get_rng_state()[:8].view(torch.int64).item()
class DistTensorRandomInitTest(DTensorTestBase):
def _run_init_op(self, init_op, *args, **kwargs):
device_mesh = self.build_device_mesh()
@ -113,22 +118,20 @@ class DistTensorRandomInitTest(DTensorTestBase):
# (`torch.distributed.tensor._random._rng_tracker._manual_seed`)
# (b) If we try to match the semantics of (a) with a user-supplied RNG, they may be very surprised to find that
# their RNG object never advances its state after using it with DTensor.
# torch.distributed.tensor._random._rng_tracker._manual_seed(55)
# rng.manual_seed(55)
# torch.nn.init.uniform_(t1, 0.0, 1.0)
# torch.nn.init.uniform_(t2, 0.0, 1.0, rng)
# self.assertEqual(t1.full_tensor(), t2.full_tensor())
torch.manual_seed(55)
rng.manual_seed(55)
torch.nn.init.uniform_(t1, 0.0, 1.0)
torch.nn.init.uniform_(t2, 0.0, 1.0, rng)
self.assertEqual(t1.full_tensor(), t2.full_tensor())
@with_comms
@skip_if_lt_x_gpu(4)
def test_meta_tensor_init(self):
# test suite sets each rank's seed to the same value but in actual
# execution the default random seed will be different (a random value).
# The DTensor random ops will use the same random seed even though the
# torch random generator keeps different seeds on ranks. This ensures
# that Replicate DTensor will have the same initialized results
# across ranks.
torch.cuda.manual_seed(self.rank)
# the test suite sets each rank's seed to the same value.
# The DTensor random ops will use the device's default generator.
# Note: this behavior changed, and now the guideline is to set the same RNG seed on all SPMD ranks.
torch.cuda.manual_seed(0)
device_mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
size = [1024, 2048]
meta_dtensor = distribute_tensor(
@ -147,7 +150,7 @@ class DistTensorRandomInitTest(DTensorTestBase):
self.assertTrue(random._rng_tracker.distribute_region_enabled)
# allgather the local tensors
local_tensor = funcol.all_gather_tensor(
gathered_local_tensors = funcol.all_gather_tensor(
dtensor.to_local(), gather_dim=0, group=(device_mesh, 0)
)
@ -158,7 +161,8 @@ class DistTensorRandomInitTest(DTensorTestBase):
# other rank should have an identical local tensor
other_slice = slice(1024 * other_rank, 1024 * other_rank + 1024)
self.assertEqual(
local_tensor[self_slice, :], local_tensor[other_slice, :]
gathered_local_tensors[self_slice, :],
gathered_local_tensors[other_slice, :],
)
# Test 2: disable the distribute region for RNG
@ -177,11 +181,11 @@ class DistTensorRandomInitTest(DTensorTestBase):
# compare with local tensors from other ranks
for other_rank in range(self.world_size):
# the RNG result on each rank differs even they're supposed
# to be replicated
# the RNG results on each rank are the same even without the help of DTensor's RNG infra,
# since the default RNG is the same across ranks.
if self.rank != other_rank:
other_slice = slice(1024 * other_rank, 1024 * other_rank + 1024)
self.assertNotEqual(
self.assertEqual(
local_tensor[self_slice, :], local_tensor[other_slice, :]
)
@ -307,7 +311,12 @@ class DistTensorRandomOpTest(DTensorTestBase):
# seed synchronization only happens after `manual_seed` or the first DTensor
# random op call
dt.uniform_(0, 1)
self.assertEqual(seed_from_rank_0, random._rng_tracker.get_seed("parallel-rng"))
# We do not maintain a copy of the seed in DTensor, but we do mutate the global RNG state,
# since we now always pull it fresh from the local device generator.
self.assertEqual(
seed_from_rank_0, get_generator_seed_for_device_type(self.device_type)
)
@with_comms
@skip_unless_torch_gpu
@ -326,11 +335,13 @@ class DistTensorRandomOpTest(DTensorTestBase):
manual_seed(self.rank, device_mesh)
# RNG tracker should already be initialized
self.assertTrue(random._rng_tracker is not None)
self.assertEqual(self.rank, random._rng_tracker.get_seed("parallel-rng"))
self.assertEqual(
self.rank, get_generator_seed_for_device_type(self.device_type)
)
# Test 2: set same seed on different ranks
manual_seed(1234, device_mesh)
self.assertEqual(1234, random._rng_tracker.get_seed("parallel-rng"))
self.assertEqual(1234, get_generator_seed_for_device_type(self.device_type))
self.assertEqual(comm_mode.get_total_counts(), 0)
@ -363,7 +374,10 @@ class DistTensorRandomOpTest(DTensorTestBase):
# set the seed for each pipeline stage to 123 + pp_rank
manual_seed(123 + pp_rank, spmd_mesh)
self.assertEqual(123 + pp_rank, random._rng_tracker.get_seed("parallel-rng"))
# DTensor no longer stores a copy of the seed, but it mutates the device's generator, so we can check that
self.assertEqual(
123 + pp_rank, get_generator_seed_for_device_type(self.device_type)
)
# mimic initializing a model weight sharded on the SPMD mesh
spmd_dtensor = torch.distributed.tensor.ones(
@ -448,14 +462,15 @@ class DistTensorRandomOpTest(DTensorTestBase):
self_slice = slice(4 * self.rank, 4 * self.rank + 4)
for other_rank in range(self.world_size):
if self.rank != other_rank:
# other rank should have an identical local tensor
# other rank should have a different local tensor for shard placement
other_slice = slice(4 * other_rank, 4 * other_rank + 4)
self.assertNotEqual(
local_tensor[self_slice, :],
local_tensor[other_slice, :],
)
torch.manual_seed(self.rank)
# set the manual seed to the same value on all SPMD ranks
torch.manual_seed(0)
dtensor = fn(size, device_mesh=device_mesh, placements=[Replicate()])
local_tensor = funcol.all_gather_tensor(
dtensor.to_local(), gather_dim=0, group=(device_mesh, 0)
@ -465,7 +480,7 @@ class DistTensorRandomOpTest(DTensorTestBase):
self_slice = slice(4 * self.rank, 4 * self.rank + 4)
for other_rank in range(self.world_size):
if self.rank != other_rank:
# other rank should have an identical local tensor
# other rank should have an identical local tensor for replicate placement
other_slice = slice(4 * other_rank, 4 * other_rank + 4)
self.assertEqual(
local_tensor[self_slice, :],


@ -2,16 +2,18 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
import contextlib
import warnings
from logging import getLogger
from typing import Optional, Union
import torch
import torch.distributed as dist
from torch import Tensor
from torch.distributed.device_mesh import _get_device_handle, DeviceMesh
from torch.distributed.tensor._dtensor_spec import DTensorSpec
from torch.distributed.tensor.placement_types import Shard
logger = getLogger(__name__)
__all__ = [
"is_rng_supported_mesh",
"manual_seed",
@ -75,22 +77,31 @@ def manual_seed(seed: int, device_mesh: DeviceMesh) -> None:
)
return
# TODO: deprecate this API, but also need to ensure we disable broadcast for PP case, and that's currently
# bundled together with this API. See torchtitan/distributed/utils.py:set_determinism
# warnings.warn(
# "DTensor manual_seed() is deprecated, since DTensor no longer maintains a separate copy of generator state. "
# "Use `torch.manual_seed` instead"
# )
# Note: we still need to ensure `run_state_sync=False` is set to support the PP case
# instantiate a RNG tracker if haven't. By default DTensor uses an
# OffsetBasedRNGTracker to perform random operators.
global _rng_tracker
if not _rng_tracker:
_rng_tracker = OffsetBasedRNGTracker(device_mesh, run_state_sync=False)
# the current rank is in mesh
if device_mesh.get_coordinate() is not None:
_rng_tracker._manual_seed(seed)
else:
if device_mesh.get_coordinate() is None:
raise RuntimeError(
"manual_seed requires the current rank to be a part of the device mesh "
"otherwise DTensor RNG state on the rank will not be initialized and "
"the behavior of DTensor random ops is undefined."
)
# DTensor no longer maintains a copy of the RNG state; manual_seed on DTensor is the same thing
# as manual_seed on torch.
torch.manual_seed(seed)
class _RNGStateTracker:
"""
@ -178,16 +189,38 @@ class OffsetBasedRNGTracker(_RNGStateTracker):
f"CUDA/CUDA-like/XPU device. Got {self._device.type} instead."
)
rng_state = self._get_device_state()
if run_state_sync:
# synchronize RNG state using rank 0's current one
torch.distributed.broadcast(rng_state, 0)
my_rng_state = self._get_device_state()
if not all(my_rng_state == rng_state):
logger.warning(
"DTensor is synchronizing RNG states of every rank with the state from rank 0. "
"This behavior is deprecated. "
"Please call `torch.manual_seed()` on every rank that participates in SPMD DTensor Operations with "
"the same seed. If using Pipeline Parallelism, each pipeling state would use a different seed, "
"but all ranks belonging to one pipeline stage would use the same seed."
)
self._set_device_state(rng_state)
def _get_device_state(self) -> torch.Tensor:
if self._device.type == "hpu":
self._device_handle.set_rng_ctx("philox")
rng_state = self._device_handle.get_rng_state().to(self._device)
if self._device.type == "hpu":
self._device_handle.unset_rng_ctx("philox")
if run_state_sync:
# synchronize RNG state using rank 0's current one
dist.broadcast(rng_state, 0)
return rng_state
self.rng_states["parallel-rng"] = rng_state.to("cpu")
def _set_device_state(self, state: torch.Tensor):
# It seems that the underlying generator wants a CPU tensor, but the DTensor code expects `_get_device_state`
# to convert to a 'device' tensor, probably because we may use it with our backend comms for sync/debug.
# For now, we just convert back to CPU here to make sure it always works.
if self._device.type == "hpu":
self._device_handle.set_rng_ctx("philox")
self._device_handle.set_rng_state(state.to("cpu"))
if self._device.type == "hpu":
self._device_handle.unset_rng_ctx("philox")
def _manual_seed(self, parallel_seed: int) -> None:
self.set_seed("parallel-rng", parallel_seed)
@ -196,7 +229,6 @@ class OffsetBasedRNGTracker(_RNGStateTracker):
def _distribute_region(
self, spec: DTensorSpec, generator: Optional[torch.Generator] = None
):
g_name = "parallel-rng"
if generator is not None:
# This is a little hacky, but for any user-passed generator, we store its state under a unique key,
# not because we need to keep a copy of it but because it's the easiest way to make it work with the
@ -204,12 +236,10 @@ class OffsetBasedRNGTracker(_RNGStateTracker):
g_name = "user-passed-generator"
assert g_name not in self.rng_states
self.rng_states[g_name] = generator.get_state()
# check if the parallel rng state has been synchronized or not
if not self.rng_state_is_sync("parallel-rng"):
raise RuntimeError(
"OffsetBasedRNGTracker requires the random state to be synchronized "
"before entering into a distribute region!"
)
else:
g_name = "parallel-rng"
assert g_name not in self.rng_states
self.rng_states[g_name] = self._get_device_state().to("cpu")
if self.distribute_region_enabled:
if self._device.type == "hpu":
@ -236,6 +266,8 @@ class OffsetBasedRNGTracker(_RNGStateTracker):
# usage of that RNG (dtensor or non-dtensor), (b) drop it from our own cache so that if the user updates
# the seed value in their rng and uses it with DTensor again, we always use the latest value
generator.set_state(self.rng_states.pop(g_name))
else:
self._set_device_state(self.rng_states.pop(g_name))
def get_offset(self, name: str) -> int:
if name not in self.rng_states: