mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[DTensor] Make default RNG semantics match user-passed generator (#160482)
Previously, DTensor kept its own copy of the generator state after the first time a random operator was called on a DTensor. This copy would evolve independently from the generator outside of DTensor. After adding support for users to pass a specific generator into random operators (e.g. `uniform_(..., generator=)`), it was determined (in discussion on #159991) to change the semantics so that any random operations performed on DTensor would evolve the state of the publicly visible generators (either the default one or user-passed one). The upsides are (1) it is now possible to call torch.manual_seed() at any point in the program and have a consistent effect on DTensor, (2) DTensor ops have an observable effect on the generator. The downside is that users are now responsible for seeding their generator before using DTensor, ensuring all ranks use the same seed. Fixes #159991 confirmed docs rendered OK <img width="897" height="414" alt="image" src="https://github.com/user-attachments/assets/c082f0f0-5447-47aa-834f-65342eb237cd" /> Pull Request resolved: https://github.com/pytorch/pytorch/pull/160482 Approved by: https://github.com/wanchaol
This commit is contained in:
parent
cc2b65a91a
commit
d1faf2ef04
|
|
@ -179,6 +179,18 @@ specifying the {class}`DeviceMesh` and {class}`Placement` for the {class}`DTenso
|
|||
|
||||
```
|
||||
|
||||
### Random Operations
|
||||
|
||||
DTensor provides distributed RNG functionality to ensure that random operations on sharded tensors get unique values, and random operations on replicated tensors get the same values. This system requires that all participating
|
||||
ranks (e.g. SPMD ranks) start out using the same generator state before each dtensor random operation is performed,
|
||||
and if this is true, it ensures they all end up at the same state after each dtensor random operation completes. There is no communication performed during random operations to synchronize RNG states.
|
||||
|
||||
Operators that accept a `generator` kwarg will utilize the user-passed generator, if passed, or the default generator for the device otherwise. Whichever generator is used, it will be advanced after the DTensor operation. It is valid to use the same generator for both DTensor and non-DTensor operations, but care must be taken to ensure the non-DTensor operations advance the generator state equally on all ranks if so.
|
||||
|
||||
When using DTensor together with Pipeline Parallelism, ranks for each pipeline stage should use a distinct seed, and ranks within a pipeline stage should use the same seed.
|
||||
|
||||
DTensor's RNG infra is based on the philox based RNG algorithm, and supports any philox based backend (cuda, and other cuda-like devices), but unfortunately does not yet support the CPU backend.
|
||||
|
||||
## Debugging
|
||||
|
||||
```{eval-rst}
|
||||
|
|
|
|||
|
|
@ -66,7 +66,7 @@ class TensorParallelRandomStateTests(DTensorTestBase):
|
|||
# in the following way:
|
||||
# - within a tensor parallel group, the RNG is set with the same seed
|
||||
# - across data parallel groups, the RNG is set with different seeds
|
||||
torch.cuda.manual_seed(dp_rank)
|
||||
torch.cuda.manual_seed(0)
|
||||
|
||||
# disable/enable parallel RNG feature
|
||||
if random._rng_tracker:
|
||||
|
|
@ -118,14 +118,10 @@ class TensorParallelRandomStateTests(DTensorTestBase):
|
|||
|
||||
# compare local shards across TP groups
|
||||
def dp_weights_assert(tensor1, tensor2):
|
||||
if enable_distribute_flag:
|
||||
# local weights shall be initialized the same across TP groups
|
||||
self.assertEqual(tensor1, tensor2)
|
||||
else:
|
||||
# without the parallel RNG, weight initialization violates the TP setup:
|
||||
# local weights are initialized differently across TP groups due to different
|
||||
# random seeds set in data loading.
|
||||
self.assertNotEqual(tensor1, tensor2)
|
||||
# local weights shall be initialized the same across TP groups,
|
||||
# and it doesn't matter whether DTensor's RNG infra is activated since all spmd ranks
|
||||
# started with the same seed.
|
||||
self.assertEqual(tensor1, tensor2)
|
||||
|
||||
self.check_gathered_tensors(
|
||||
dp_rank, dp_size, tensor_gather, dp_weights_assert
|
||||
|
|
|
|||
|
|
@ -33,6 +33,11 @@ from torch.testing._internal.distributed._tensor.common_dtensor import (
|
|||
)
|
||||
|
||||
|
||||
def get_generator_seed_for_device_type(device_type: str) -> int:
|
||||
device_module = torch.get_device_module(device_type)
|
||||
return device_module.get_rng_state()[:8].view(torch.int64).item()
|
||||
|
||||
|
||||
class DistTensorRandomInitTest(DTensorTestBase):
|
||||
def _run_init_op(self, init_op, *args, **kwargs):
|
||||
device_mesh = self.build_device_mesh()
|
||||
|
|
@ -113,22 +118,20 @@ class DistTensorRandomInitTest(DTensorTestBase):
|
|||
# (`torch.distributed.tensor._random._rng_tracker._manual_seed`)
|
||||
# (b) If we try to match the semantics of (a) with a user-supplied RNG, they may be very surprised to find that
|
||||
# their RNG object never advances its state after using it with DTensor.
|
||||
# torch.distributed.tensor._random._rng_tracker._manual_seed(55)
|
||||
# rng.manual_seed(55)
|
||||
# torch.nn.init.uniform_(t1, 0.0, 1.0)
|
||||
# torch.nn.init.uniform_(t2, 0.0, 1.0, rng)
|
||||
# self.assertEqual(t1.full_tensor(), t2.full_tensor())
|
||||
torch.manual_seed(55)
|
||||
rng.manual_seed(55)
|
||||
torch.nn.init.uniform_(t1, 0.0, 1.0)
|
||||
torch.nn.init.uniform_(t2, 0.0, 1.0, rng)
|
||||
self.assertEqual(t1.full_tensor(), t2.full_tensor())
|
||||
|
||||
@with_comms
|
||||
@skip_if_lt_x_gpu(4)
|
||||
def test_meta_tensor_init(self):
|
||||
# test suite sets each rank's seed to the same value but in actual
|
||||
# execution the default random seed will be different (a random value).
|
||||
# The DTensor random ops will use the same random seed even though the
|
||||
# torch random generator keeps different seeds on ranks. This ensures
|
||||
# that Replicate DTensor will have the same initialized results
|
||||
# across ranks.
|
||||
torch.cuda.manual_seed(self.rank)
|
||||
# test suite sets each rank's seed to the same value.
|
||||
# The DTensor random ops will use the same generator as the default one on the device.
|
||||
|
||||
# Note: this behavior changed, and now the guideline is to set the same RNG seed on all SPMD ranks.
|
||||
torch.cuda.manual_seed(0)
|
||||
device_mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
|
||||
size = [1024, 2048]
|
||||
meta_dtensor = distribute_tensor(
|
||||
|
|
@ -147,7 +150,7 @@ class DistTensorRandomInitTest(DTensorTestBase):
|
|||
self.assertTrue(random._rng_tracker.distribute_region_enabled)
|
||||
|
||||
# allgather the local tensors
|
||||
local_tensor = funcol.all_gather_tensor(
|
||||
gathered_local_tensors = funcol.all_gather_tensor(
|
||||
dtensor.to_local(), gather_dim=0, group=(device_mesh, 0)
|
||||
)
|
||||
|
||||
|
|
@ -158,7 +161,8 @@ class DistTensorRandomInitTest(DTensorTestBase):
|
|||
# other rank should have an identical local tensor
|
||||
other_slice = slice(1024 * other_rank, 1024 * other_rank + 1024)
|
||||
self.assertEqual(
|
||||
local_tensor[self_slice, :], local_tensor[other_slice, :]
|
||||
gathered_local_tensors[self_slice, :],
|
||||
gathered_local_tensors[other_slice, :],
|
||||
)
|
||||
|
||||
# Test 2: disable the distribute region for RNG
|
||||
|
|
@ -177,11 +181,11 @@ class DistTensorRandomInitTest(DTensorTestBase):
|
|||
|
||||
# compare with local tensors from other ranks
|
||||
for other_rank in range(self.world_size):
|
||||
# the RNG result on each rank differs even they're supposed
|
||||
# to be replicated
|
||||
# the RNG result on each rank are the same even without the help of DTensor's RNG infra,
|
||||
# since the default RNG is the same across ranks.
|
||||
if self.rank != other_rank:
|
||||
other_slice = slice(1024 * other_rank, 1024 * other_rank + 1024)
|
||||
self.assertNotEqual(
|
||||
self.assertEqual(
|
||||
local_tensor[self_slice, :], local_tensor[other_slice, :]
|
||||
)
|
||||
|
||||
|
|
@ -307,7 +311,12 @@ class DistTensorRandomOpTest(DTensorTestBase):
|
|||
# seed synchronization only happens after `manual_seed` or the first DTensor
|
||||
# random op call
|
||||
dt.uniform_(0, 1)
|
||||
self.assertEqual(seed_from_rank_0, random._rng_tracker.get_seed("parallel-rng"))
|
||||
|
||||
# We do not maintain the copy of the seed in dtensor, but we do mutate the global rng state
|
||||
# since we now always pull it fresh from the local device generator
|
||||
self.assertEqual(
|
||||
seed_from_rank_0, get_generator_seed_for_device_type(self.device_type)
|
||||
)
|
||||
|
||||
@with_comms
|
||||
@skip_unless_torch_gpu
|
||||
|
|
@ -326,11 +335,13 @@ class DistTensorRandomOpTest(DTensorTestBase):
|
|||
manual_seed(self.rank, device_mesh)
|
||||
# RNG tracker should already be initialized
|
||||
self.assertTrue(random._rng_tracker is not None)
|
||||
self.assertEqual(self.rank, random._rng_tracker.get_seed("parallel-rng"))
|
||||
self.assertEqual(
|
||||
self.rank, get_generator_seed_for_device_type(self.device_type)
|
||||
)
|
||||
|
||||
# Test 2: set same seed on different ranks
|
||||
manual_seed(1234, device_mesh)
|
||||
self.assertEqual(1234, random._rng_tracker.get_seed("parallel-rng"))
|
||||
self.assertEqual(1234, get_generator_seed_for_device_type(self.device_type))
|
||||
|
||||
self.assertEqual(comm_mode.get_total_counts(), 0)
|
||||
|
||||
|
|
@ -363,7 +374,10 @@ class DistTensorRandomOpTest(DTensorTestBase):
|
|||
|
||||
# set the seed for each pipeline stage to 123 + pp_rank
|
||||
manual_seed(123 + pp_rank, spmd_mesh)
|
||||
self.assertEqual(123 + pp_rank, random._rng_tracker.get_seed("parallel-rng"))
|
||||
# dtensor no longer stores a copy of the seed, but it mutates the device's generator so we can check that
|
||||
self.assertEqual(
|
||||
123 + pp_rank, get_generator_seed_for_device_type(self.device_type)
|
||||
)
|
||||
|
||||
# mimic initializing a model weight sharded on the SPMD mesh
|
||||
spmd_dtensor = torch.distributed.tensor.ones(
|
||||
|
|
@ -448,14 +462,15 @@ class DistTensorRandomOpTest(DTensorTestBase):
|
|||
self_slice = slice(4 * self.rank, 4 * self.rank + 4)
|
||||
for other_rank in range(self.world_size):
|
||||
if self.rank != other_rank:
|
||||
# other rank should have an identical local tensor
|
||||
# other rank should have a different local tensor for shard placement
|
||||
other_slice = slice(4 * other_rank, 4 * other_rank + 4)
|
||||
self.assertNotEqual(
|
||||
local_tensor[self_slice, :],
|
||||
local_tensor[other_slice, :],
|
||||
)
|
||||
|
||||
torch.manual_seed(self.rank)
|
||||
# we should set manual seed to the same value on all SPMD ranks
|
||||
torch.manual_seed(0)
|
||||
dtensor = fn(size, device_mesh=device_mesh, placements=[Replicate()])
|
||||
local_tensor = funcol.all_gather_tensor(
|
||||
dtensor.to_local(), gather_dim=0, group=(device_mesh, 0)
|
||||
|
|
@ -465,7 +480,7 @@ class DistTensorRandomOpTest(DTensorTestBase):
|
|||
self_slice = slice(4 * self.rank, 4 * self.rank + 4)
|
||||
for other_rank in range(self.world_size):
|
||||
if self.rank != other_rank:
|
||||
# other rank should have an identical local tensor
|
||||
# other rank should have an identical local tensor for replicate placement
|
||||
other_slice = slice(4 * other_rank, 4 * other_rank + 4)
|
||||
self.assertEqual(
|
||||
local_tensor[self_slice, :],
|
||||
|
|
|
|||
|
|
@ -2,16 +2,18 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates
|
||||
import contextlib
|
||||
import warnings
|
||||
from logging import getLogger
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch import Tensor
|
||||
from torch.distributed.device_mesh import _get_device_handle, DeviceMesh
|
||||
from torch.distributed.tensor._dtensor_spec import DTensorSpec
|
||||
from torch.distributed.tensor.placement_types import Shard
|
||||
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
__all__ = [
|
||||
"is_rng_supported_mesh",
|
||||
"manual_seed",
|
||||
|
|
@ -75,22 +77,31 @@ def manual_seed(seed: int, device_mesh: DeviceMesh) -> None:
|
|||
)
|
||||
return
|
||||
|
||||
# TODO: deprecate this API, but also need to ensure we disable broadcast for PP case, and that's currently
|
||||
# bundled together with this API. See torchtitan/distributed/utils.py:set_determinism
|
||||
# warnings.warn(
|
||||
# "DTensor manual_seed() is deprecated, since DTensor no longer maintains a separate copy of generator state. "
|
||||
# "Use `torch.manual_seed` instead"
|
||||
# )
|
||||
# Note: we still need to ensure setting `run_state_sync=False` to support the the pp case
|
||||
|
||||
# instantiate a RNG tracker if haven't. By default DTensor uses an
|
||||
# OffsetBasedRNGTracker to perform random operators.
|
||||
global _rng_tracker
|
||||
if not _rng_tracker:
|
||||
_rng_tracker = OffsetBasedRNGTracker(device_mesh, run_state_sync=False)
|
||||
|
||||
# the current rank is in mesh
|
||||
if device_mesh.get_coordinate() is not None:
|
||||
_rng_tracker._manual_seed(seed)
|
||||
else:
|
||||
if device_mesh.get_coordinate() is None:
|
||||
raise RuntimeError(
|
||||
"manual_seed requires the current rank to be a part of the device mesh "
|
||||
"otherwise DTensor RNG state on the rank will not be initialized and "
|
||||
"the behavior of DTensor random ops is undefined."
|
||||
)
|
||||
|
||||
# DTensor no longer maintains a copy of rng state. manual seed on dtensor is the same thing
|
||||
# as manual seed on torch.
|
||||
torch.manual_seed(seed)
|
||||
|
||||
|
||||
class _RNGStateTracker:
|
||||
"""
|
||||
|
|
@ -178,16 +189,38 @@ class OffsetBasedRNGTracker(_RNGStateTracker):
|
|||
f"CUDA/CUDA-like/XPU device. Got {self._device.type} instead."
|
||||
)
|
||||
|
||||
rng_state = self._get_device_state()
|
||||
if run_state_sync:
|
||||
# synchronize RNG state using rank 0's current one
|
||||
torch.distributed.broadcast(rng_state, 0)
|
||||
my_rng_state = self._get_device_state()
|
||||
if not all(my_rng_state == rng_state):
|
||||
logger.warning(
|
||||
"DTensor is synchronizing RNG states of every rank with the state from rank 0. "
|
||||
"This behavior is deprecated. "
|
||||
"Please call `torch.manual_seed()` on every rank that participates in SPMD DTensor Operations with "
|
||||
"the same seed. If using Pipeline Parallelism, each pipeling state would use a different seed, "
|
||||
"but all ranks belonging to one pipeline stage would use the same seed."
|
||||
)
|
||||
self._set_device_state(rng_state)
|
||||
|
||||
def _get_device_state(self) -> torch.Tensor:
|
||||
if self._device.type == "hpu":
|
||||
self._device_handle.set_rng_ctx("philox")
|
||||
rng_state = self._device_handle.get_rng_state().to(self._device)
|
||||
if self._device.type == "hpu":
|
||||
self._device_handle.unset_rng_ctx("philox")
|
||||
if run_state_sync:
|
||||
# synchronize RNG state using rank 0's current one
|
||||
dist.broadcast(rng_state, 0)
|
||||
return rng_state
|
||||
|
||||
self.rng_states["parallel-rng"] = rng_state.to("cpu")
|
||||
def _set_device_state(self, state: torch.Tensor):
|
||||
# It seems that the underlying generator wants a cpu tensor but the dtensor code expects `_get_device_state`
|
||||
# to convert to a 'device' tensor, probably because we may use it with our backend comms for sync/debug
|
||||
# for now, we just convert back to cpu here to make sure it always works.
|
||||
if self._device.type == "hpu":
|
||||
self._device_handle.set_rng_ctx("philox")
|
||||
self._device_handle.set_rng_state(state.to("cpu"))
|
||||
if self._device.type == "hpu":
|
||||
self._device_handle.unset_rng_ctx("philox")
|
||||
|
||||
def _manual_seed(self, parallel_seed: int) -> None:
|
||||
self.set_seed("parallel-rng", parallel_seed)
|
||||
|
|
@ -196,7 +229,6 @@ class OffsetBasedRNGTracker(_RNGStateTracker):
|
|||
def _distribute_region(
|
||||
self, spec: DTensorSpec, generator: Optional[torch.Generator] = None
|
||||
):
|
||||
g_name = "parallel-rng"
|
||||
if generator is not None:
|
||||
# This is a little hacky, but for any user-passed generator, we store its state under a unique key,
|
||||
# not because we need to keep a copy of it but because its the easiest way to make it work with the
|
||||
|
|
@ -204,12 +236,10 @@ class OffsetBasedRNGTracker(_RNGStateTracker):
|
|||
g_name = "user-passed-generator"
|
||||
assert g_name not in self.rng_states
|
||||
self.rng_states[g_name] = generator.get_state()
|
||||
# check if the parallel rng state has been synchronized or not
|
||||
if not self.rng_state_is_sync("parallel-rng"):
|
||||
raise RuntimeError(
|
||||
"OffsetBasedRNGTracker requires the random state to be synchronized "
|
||||
"before entering into a distribute region!"
|
||||
)
|
||||
else:
|
||||
g_name = "parallel-rng"
|
||||
assert g_name not in self.rng_states
|
||||
self.rng_states[g_name] = self._get_device_state().to("cpu")
|
||||
|
||||
if self.distribute_region_enabled:
|
||||
if self._device.type == "hpu":
|
||||
|
|
@ -236,6 +266,8 @@ class OffsetBasedRNGTracker(_RNGStateTracker):
|
|||
# usage of that RNG (dtensor or non-dtensor), (b) drop it from our own cache so that if the user updates
|
||||
# the seed value in their rng and uses it with DTensor again, we always use the latest value
|
||||
generator.set_state(self.rng_states.pop(g_name))
|
||||
else:
|
||||
self._set_device_state(self.rng_states.pop(g_name))
|
||||
|
||||
def get_offset(self, name: str) -> int:
|
||||
if name not in self.rng_states:
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user