Revert "[DTensor] Make default RNG semantics match user-passed generator (#160482)"
This reverts commit d1faf2ef04.
Reverted https://github.com/pytorch/pytorch/pull/160482 on behalf of https://github.com/jeffdaily due to failing cuda and rocm jobs ([comment](https://github.com/pytorch/pytorch/pull/160482#issuecomment-3214694297))
This commit is contained in:
parent ce467df5d1
commit c7a77470c5
@@ -179,18 +179,6 @@ specifying the {class}`DeviceMesh` and {class}`Placement` for the {class}`DTenso
 ```
 
-### Random Operations
-
-DTensor provides distributed RNG functionality to ensure that random operations on sharded tensors get unique values and random operations on replicated tensors get the same values. This system requires that all participating
-ranks (e.g. SPMD ranks) start out with the same generator state before each DTensor random operation is performed;
-if this holds, it ensures they all end up in the same state after each DTensor random operation completes. No communication is performed during random operations to synchronize RNG states.
-
-Operators that accept a `generator` kwarg will use the user-passed generator if one is provided, or the device's default generator otherwise. Whichever generator is used, it is advanced after the DTensor operation. It is valid to use the same generator for both DTensor and non-DTensor operations, but care must be taken to ensure that the non-DTensor operations advance the generator state equally on all ranks.
-
-When using DTensor together with Pipeline Parallelism, ranks in each pipeline stage should use a distinct seed, and ranks within a pipeline stage should use the same seed.
-
-DTensor's RNG infrastructure is based on the Philox RNG algorithm and supports any Philox-based backend (CUDA and CUDA-like devices), but it does not yet support the CPU backend.
-
 ## Debugging
 
 ```{eval-rst}
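For context, here is a minimal sketch (not part of this diff) of the seeding guidance described in the removed section above, i.e. the semantics this commit reverts. The process-group setup, mesh shape, and tensor sizes are illustrative assumptions:

```python
# Illustrative only: assumes 4 GPUs on one node, launched with torchrun.
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Replicate, Shard, distribute_tensor

dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank())
mesh = init_device_mesh("cuda", (dist.get_world_size(),))

# Per the removed docs, every SPMD rank starts from the same generator state.
torch.manual_seed(0)

src = torch.empty(1024, 1024, device="cuda")
# Replicated DTensor: the random op draws the same values on every rank.
replicated = distribute_tensor(src, mesh, [Replicate()]).uniform_(0, 1)
# Sharded DTensor: each shard draws distinct values.
sharded = distribute_tensor(src, mesh, [Shard(0)]).uniform_(0, 1)
```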
@@ -66,7 +66,7 @@ class TensorParallelRandomStateTests(DTensorTestBase):
         # in the following way:
         # - within a tensor parallel group, the RNG is set with the same seed
         # - across data parallel groups, the RNG is set with different seeds
-        torch.cuda.manual_seed(0)
+        torch.cuda.manual_seed(dp_rank)
 
         # disable/enable parallel RNG feature
         if random._rng_tracker:
@@ -118,10 +118,14 @@ class TensorParallelRandomStateTests(DTensorTestBase):
 
             # compare local shards across TP groups
             def dp_weights_assert(tensor1, tensor2):
-                # local weights shall be initialized the same across TP groups,
-                # and it doesn't matter whether DTensor's RNG infra is activated since all spmd ranks
-                # started with the same seed.
-                self.assertEqual(tensor1, tensor2)
+                if enable_distribute_flag:
+                    # local weights shall be initialized the same across TP groups
+                    self.assertEqual(tensor1, tensor2)
+                else:
+                    # without the parallel RNG, weight initialization violates the TP setup:
+                    # local weights are initialized differently across TP groups due to different
+                    # random seeds set in data loading.
+                    self.assertNotEqual(tensor1, tensor2)
 
             self.check_gathered_tensors(
                 dp_rank, dp_size, tensor_gather, dp_weights_assert
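For context, a sketch (not part of this diff) of the seeding pattern the restored test comments describe: the same seed within a tensor-parallel group, different seeds across data-parallel groups. The 2x2 mesh shape and dimension names are assumptions:

```python
# Illustrative only: assumes 4 GPUs arranged as a 2x2 (dp, tp) mesh.
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh

dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank())
mesh = init_device_mesh("cuda", (2, 2), mesh_dim_names=("dp", "tp"))

dp_rank = mesh.get_local_rank("dp")
# Ranks that share a dp_rank (one TP group) get the same seed;
# different DP groups get different seeds, as in the restored test.
torch.cuda.manual_seed(dp_rank)
```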
@@ -33,11 +33,6 @@ from torch.testing._internal.distributed._tensor.common_dtensor import (
 )
 
 
-def get_generator_seed_for_device_type(device_type: str) -> int:
-    device_module = torch.get_device_module(device_type)
-    return device_module.get_rng_state()[:8].view(torch.int64).item()
-
-
 class DistTensorRandomInitTest(DTensorTestBase):
     def _run_init_op(self, init_op, *args, **kwargs):
         device_mesh = self.build_device_mesh()
@@ -118,20 +113,22 @@ class DistTensorRandomInitTest(DTensorTestBase):
         # (`torch.distributed.tensor._random._rng_tracker._manual_seed`)
         # (b) If we try to match the semantics of (a) with a user-supplied RNG, they may be very surprised to find that
         # their RNG object never advances its state after using it with DTensor.
-        torch.manual_seed(55)
-        rng.manual_seed(55)
-        torch.nn.init.uniform_(t1, 0.0, 1.0)
-        torch.nn.init.uniform_(t2, 0.0, 1.0, rng)
-        self.assertEqual(t1.full_tensor(), t2.full_tensor())
+        # torch.distributed.tensor._random._rng_tracker._manual_seed(55)
+        # rng.manual_seed(55)
+        # torch.nn.init.uniform_(t1, 0.0, 1.0)
+        # torch.nn.init.uniform_(t2, 0.0, 1.0, rng)
+        # self.assertEqual(t1.full_tensor(), t2.full_tensor())
 
     @with_comms
     @skip_if_lt_x_gpu(4)
     def test_meta_tensor_init(self):
-        # test suite sets each rank's seed to the same value.
-        # The DTensor random ops will use the same generator as the default one on the device.
-
-        # Note: this behavior changed, and now the guideline is to set the same RNG seed on all SPMD ranks.
-        torch.cuda.manual_seed(0)
+        # test suite sets each rank's seed to the same value but in actual
+        # execution the default random seed will be different (a random value).
+        # The DTensor random ops will use the same random seed even though the
+        # torch random generator keeps different seeds on ranks. This ensures
+        # that Replicate DTensor will have the same initialized results
+        # across ranks.
+        torch.cuda.manual_seed(self.rank)
         device_mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
         size = [1024, 2048]
         meta_dtensor = distribute_tensor(
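For context, a sketch (not part of this diff) of the user-passed-generator call shape that the commented-out assertions above exercise; whether its results match the default generator is precisely the semantic question this revert is about. The mesh and sizes are assumptions:

```python
# Illustrative only: assumes an initialized 1-D DeviceMesh `mesh` on a CUDA backend.
import torch
from torch.distributed.tensor import Replicate, distribute_tensor

rng = torch.Generator(device="cuda")
rng.manual_seed(55)

t = distribute_tensor(torch.empty(1024, 1024, device="cuda"), mesh, [Replicate()])
# DTensor random ops accept an explicit generator; the tracker saves and restores
# its state around the op (see the "user-passed-generator" handling later in this diff).
torch.nn.init.uniform_(t, 0.0, 1.0, generator=rng)
```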
@@ -150,7 +147,7 @@ class DistTensorRandomInitTest(DTensorTestBase):
         self.assertTrue(random._rng_tracker.distribute_region_enabled)
 
         # allgather the local tensors
-        gathered_local_tensors = funcol.all_gather_tensor(
+        local_tensor = funcol.all_gather_tensor(
             dtensor.to_local(), gather_dim=0, group=(device_mesh, 0)
         )
 
@@ -161,8 +158,7 @@ class DistTensorRandomInitTest(DTensorTestBase):
                 # other rank should have an identical local tensor
                 other_slice = slice(1024 * other_rank, 1024 * other_rank + 1024)
                 self.assertEqual(
-                    gathered_local_tensors[self_slice, :],
-                    gathered_local_tensors[other_slice, :],
+                    local_tensor[self_slice, :], local_tensor[other_slice, :]
                 )
 
         # Test 2: disable the distribute region for RNG
@@ -181,11 +177,11 @@ class DistTensorRandomInitTest(DTensorTestBase):
 
         # compare with local tensors from other ranks
         for other_rank in range(self.world_size):
-            # the RNG result on each rank are the same even without the help of DTensor's RNG infra,
-            # since the default RNG is the same across ranks.
+            # the RNG result on each rank differs even they're supposed
+            # to be replicated
             if self.rank != other_rank:
                 other_slice = slice(1024 * other_rank, 1024 * other_rank + 1024)
                 self.assertNotEqual(
                     local_tensor[self_slice, :], local_tensor[other_slice, :]
                 )
 
@@ -311,12 +307,7 @@ class DistTensorRandomOpTest(DTensorTestBase):
         # seed synchronization only happens after `manual_seed` or the first DTensor
         # random op call
         dt.uniform_(0, 1)
-
-        # We do not maintain the copy of the seed in dtensor, but we do mutate the global rng state
-        # since we now always pull it fresh from the local device generator
-        self.assertEqual(
-            seed_from_rank_0, get_generator_seed_for_device_type(self.device_type)
-        )
+        self.assertEqual(seed_from_rank_0, random._rng_tracker.get_seed("parallel-rng"))
 
     @with_comms
     @skip_unless_torch_gpu
@@ -335,13 +326,11 @@ class DistTensorRandomOpTest(DTensorTestBase):
         manual_seed(self.rank, device_mesh)
         # RNG tracker should already be initialized
         self.assertTrue(random._rng_tracker is not None)
-        self.assertEqual(
-            self.rank, get_generator_seed_for_device_type(self.device_type)
-        )
+        self.assertEqual(self.rank, random._rng_tracker.get_seed("parallel-rng"))
 
         # Test 2: set same seed on different ranks
         manual_seed(1234, device_mesh)
-        self.assertEqual(1234, get_generator_seed_for_device_type(self.device_type))
+        self.assertEqual(1234, random._rng_tracker.get_seed("parallel-rng"))
 
         self.assertEqual(comm_mode.get_total_counts(), 0)
 
@@ -374,10 +363,7 @@ class DistTensorRandomOpTest(DTensorTestBase):
 
         # set the seed for each pipeline stage to 123 + pp_rank
        manual_seed(123 + pp_rank, spmd_mesh)
-        # dtensor no longer stores a copy of the seed, but it mutates the device's generator so we can check that
-        self.assertEqual(
-            123 + pp_rank, get_generator_seed_for_device_type(self.device_type)
-        )
+        self.assertEqual(123 + pp_rank, random._rng_tracker.get_seed("parallel-rng"))
 
         # mimic initializing a model weight sharded on the SPMD mesh
         spmd_dtensor = torch.distributed.tensor.ones(
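For context, a sketch (not part of this diff) of the pipeline-parallel seeding pattern the test above exercises: a distinct seed per pipeline stage, the same seed within a stage. The mesh layout and the private import path are assumptions based on this diff:

```python
# Illustrative only: assumes 4 GPUs split into 2 pipeline stages x 2 SPMD ranks.
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor._random import manual_seed  # private helper changed in this diff

dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank())
mesh = init_device_mesh("cuda", (2, 2), mesh_dim_names=("pp", "spmd"))

pp_rank = mesh.get_local_rank("pp")
spmd_mesh = mesh["spmd"]

# Distinct seed per pipeline stage, identical seed within a stage.
manual_seed(123 + pp_rank, spmd_mesh)
```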
@@ -462,15 +448,14 @@ class DistTensorRandomOpTest(DTensorTestBase):
         self_slice = slice(4 * self.rank, 4 * self.rank + 4)
         for other_rank in range(self.world_size):
             if self.rank != other_rank:
-                # other rank should have a different local tensor for shard placement
+                # other rank should have an identical local tensor
                 other_slice = slice(4 * other_rank, 4 * other_rank + 4)
                 self.assertNotEqual(
                     local_tensor[self_slice, :],
                     local_tensor[other_slice, :],
                 )
 
-        # we should set manual seed to the same value on all SPMD ranks
-        torch.manual_seed(0)
+        torch.manual_seed(self.rank)
         dtensor = fn(size, device_mesh=device_mesh, placements=[Replicate()])
         local_tensor = funcol.all_gather_tensor(
             dtensor.to_local(), gather_dim=0, group=(device_mesh, 0)
@@ -480,7 +465,7 @@ class DistTensorRandomOpTest(DTensorTestBase):
         self_slice = slice(4 * self.rank, 4 * self.rank + 4)
         for other_rank in range(self.world_size):
             if self.rank != other_rank:
-                # other rank should have an identical local tensor for replicate placement
+                # other rank should have an identical local tensor
                 other_slice = slice(4 * other_rank, 4 * other_rank + 4)
                 self.assertEqual(
                     local_tensor[self_slice, :],
@@ -2,18 +2,16 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates
 import contextlib
 import warnings
-from logging import getLogger
 from typing import Optional, Union
 
 import torch
+import torch.distributed as dist
 from torch import Tensor
 from torch.distributed.device_mesh import _get_device_handle, DeviceMesh
 from torch.distributed.tensor._dtensor_spec import DTensorSpec
 from torch.distributed.tensor.placement_types import Shard
 
 
-logger = getLogger(__name__)
-
 __all__ = [
     "is_rng_supported_mesh",
     "manual_seed",
@@ -77,31 +75,22 @@ def manual_seed(seed: int, device_mesh: DeviceMesh) -> None:
         )
         return
 
-    # TODO: deprecate this API, but also need to ensure we disable broadcast for PP case, and that's currently
-    # bundled together with this API. See torchtitan/distributed/utils.py:set_determinism
-    # warnings.warn(
-    #     "DTensor manual_seed() is deprecated, since DTensor no longer maintains a separate copy of generator state. "
-    #     "Use `torch.manual_seed` instead"
-    # )
-    # Note: we still need to ensure setting `run_state_sync=False` to support the the pp case
-
     # instantiate a RNG tracker if haven't. By default DTensor uses an
     # OffsetBasedRNGTracker to perform random operators.
     global _rng_tracker
     if not _rng_tracker:
         _rng_tracker = OffsetBasedRNGTracker(device_mesh, run_state_sync=False)
 
-    if device_mesh.get_coordinate() is None:
+    # the current rank is in mesh
+    if device_mesh.get_coordinate() is not None:
+        _rng_tracker._manual_seed(seed)
+    else:
         raise RuntimeError(
             "manual_seed requires the current rank to be a part of the device mesh "
             "otherwise DTensor RNG state on the rank will not be initialized and "
             "the behavior of DTensor random ops is undefined."
         )
 
-    # DTensor no longer maintains a copy of rng state. manual seed on dtensor is the same thing
-    # as manual seed on torch.
-    torch.manual_seed(seed)
-
 
 class _RNGStateTracker:
     """
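For context, a sketch (not part of this diff) of how the restored `manual_seed` behaves: it seeds only the DTensor RNG tracker's "parallel-rng" state and raises if the calling rank is not part of the mesh, whereas the reverted version also called `torch.manual_seed`. The setup below is an assumption:

```python
# Illustrative only: assumes an initialized default process group on a CUDA backend.
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor._random import is_rng_supported_mesh, manual_seed

mesh = init_device_mesh("cuda", (dist.get_world_size(),))

if is_rng_supported_mesh(mesh):
    # Seeds the tracker's "parallel-rng" state on every rank of the mesh;
    # in the restored code it does not touch the global torch generator.
    manual_seed(1234, mesh)
```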
@@ -189,38 +178,16 @@ class OffsetBasedRNGTracker(_RNGStateTracker):
                 f"CUDA/CUDA-like/XPU device. Got {self._device.type} instead."
             )
 
-        rng_state = self._get_device_state()
-        if run_state_sync:
-            # synchronize RNG state using rank 0's current one
-            torch.distributed.broadcast(rng_state, 0)
-            my_rng_state = self._get_device_state()
-            if not all(my_rng_state == rng_state):
-                logger.warning(
-                    "DTensor is synchronizing RNG states of every rank with the state from rank 0. "
-                    "This behavior is deprecated. "
-                    "Please call `torch.manual_seed()` on every rank that participates in SPMD DTensor Operations with "
-                    "the same seed. If using Pipeline Parallelism, each pipeling state would use a different seed, "
-                    "but all ranks belonging to one pipeline stage would use the same seed."
-                )
-                self._set_device_state(rng_state)
-
-    def _get_device_state(self) -> torch.Tensor:
         if self._device.type == "hpu":
             self._device_handle.set_rng_ctx("philox")
         rng_state = self._device_handle.get_rng_state().to(self._device)
         if self._device.type == "hpu":
             self._device_handle.unset_rng_ctx("philox")
-        return rng_state
+        if run_state_sync:
+            # synchronize RNG state using rank 0's current one
+            dist.broadcast(rng_state, 0)
 
-    def _set_device_state(self, state: torch.Tensor):
-        # It seems that the underlying generator wants a cpu tensor but the dtensor code expects `_get_device_state`
-        # to convert to a 'device' tensor, probably because we may use it with our backend comms for sync/debug
-        # for now, we just convert back to cpu here to make sure it always works.
-        if self._device.type == "hpu":
-            self._device_handle.set_rng_ctx("philox")
-        self._device_handle.set_rng_state(state.to("cpu"))
-        if self._device.type == "hpu":
-            self._device_handle.unset_rng_ctx("philox")
+        self.rng_states["parallel-rng"] = rng_state.to("cpu")
 
     def _manual_seed(self, parallel_seed: int) -> None:
         self.set_seed("parallel-rng", parallel_seed)
@@ -229,6 +196,7 @@ class OffsetBasedRNGTracker(_RNGStateTracker):
     def _distribute_region(
         self, spec: DTensorSpec, generator: Optional[torch.Generator] = None
     ):
+        g_name = "parallel-rng"
         if generator is not None:
             # This is a little hacky, but for any user-passed generator, we store its state under a unique key,
             # not because we need to keep a copy of it but because its the easiest way to make it work with the
@@ -236,10 +204,12 @@ class OffsetBasedRNGTracker(_RNGStateTracker):
             g_name = "user-passed-generator"
             assert g_name not in self.rng_states
             self.rng_states[g_name] = generator.get_state()
-        else:
-            g_name = "parallel-rng"
-            assert g_name not in self.rng_states
-            self.rng_states[g_name] = self._get_device_state().to("cpu")
+        # check if the parallel rng state has been synchronized or not
+        if not self.rng_state_is_sync("parallel-rng"):
+            raise RuntimeError(
+                "OffsetBasedRNGTracker requires the random state to be synchronized "
+                "before entering into a distribute region!"
+            )
 
         if self.distribute_region_enabled:
             if self._device.type == "hpu":
@@ -266,8 +236,6 @@ class OffsetBasedRNGTracker(_RNGStateTracker):
             # usage of that RNG (dtensor or non-dtensor), (b) drop it from our own cache so that if the user updates
            # the seed value in their rng and uses it with DTensor again, we always use the latest value
             generator.set_state(self.rng_states.pop(g_name))
-        else:
-            self._set_device_state(self.rng_states.pop(g_name))
 
     def get_offset(self, name: str) -> int:
         if name not in self.rng_states: