Avoid unnecessary clone in torch.cuda.set_rng_state (#149283)
The unconditional clone has a performance cost, as noted in f49c3eb6e6/megatron/core/tensor_parallel/random.py (L77-L80).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/149283
Approved by: https://github.com/cyyever, https://github.com/Skylion007
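For context, here is a minimal sketch (not taken from Megatron or from this patch; the helper name is made up for illustration) of the kind of hot path where a per-call clone adds up: checkpoint/recompute code that swaps the CUDA RNG state around every forked region, assuming a CUDA device is available when it runs.

import torch

def _run_with_cuda_rng_state(state, fn, *args):
    # Hypothetical helper, for illustration only: restore a saved CUDA RNG
    # state, run fn, then put the previous state back. Each call makes two
    # set_rng_state calls, so an unconditional clone inside set_rng_state
    # is paid on every recomputation.
    prev = torch.cuda.get_rng_state()
    torch.cuda.set_rng_state(state)
    try:
        return fn(*args)
    finally:
        torch.cuda.set_rng_state(prev)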
parent cd5c13d8f0
commit d80a70b58a
torch/cuda/random.py

@@ -5,7 +5,7 @@ from typing import Union

 import torch
 from torch import Tensor

-from . import _lazy_call, _lazy_init, current_device, device_count
+from . import _lazy_call, _lazy_init, current_device, device_count, is_initialized


 __all__ = [
@@ -59,8 +59,11 @@ def set_rng_state(
         device (torch.device or int, optional): The device to set the RNG state.
             Default: ``'cuda'`` (i.e., ``torch.device('cuda')``, the current CUDA device).
     """
-    with torch._C._DisableFuncTorch():
-        new_state_copy = new_state.clone(memory_format=torch.contiguous_format)
+    if not is_initialized():
+        with torch._C._DisableFuncTorch():
+            # Clone the state because the callback will be triggered
+            # later when CUDA is lazy initialized.
+            new_state = new_state.clone(memory_format=torch.contiguous_format)
     if isinstance(device, str):
         device = torch.device(device)
     elif isinstance(device, int):
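Why the clone is now gated on is_initialized(): set_rng_state applies the state through _lazy_call, which defers the callback until CUDA finishes lazy initialization, and only in that deferred case does the function need its own snapshot of new_state. Once CUDA is initialized the callback runs right away, so the input tensor can be used directly. A small illustrative sketch of the already-initialized case (not part of the patch; assumes a CUDA device is available):

import torch

torch.cuda.init()                    # force CUDA initialization up front
state = torch.cuda.get_rng_state()   # CPU ByteTensor snapshot of the generator state
torch.cuda.set_rng_state(state)      # is_initialized() is True, so the state is
                                     # applied immediately, no deferred callback
state.zero_()                        # mutating the tensor afterwards is harmless,
                                     # since set_state already copied its contents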
@@ -71,7 +74,7 @@ def set_rng_state(
         if idx is None:
             idx = current_device()
         default_generator = torch.cuda.default_generators[idx]
-        default_generator.set_state(new_state_copy)
+        default_generator.set_state(new_state)

     _lazy_call(cb)
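A quick round-trip check of the user-facing behavior the patch leaves unchanged (a sketch, assuming a CUDA device is available): restoring a saved state reproduces the same random draw.

import torch

if torch.cuda.is_available():
    saved = torch.cuda.get_rng_state()   # snapshot the current state
    a = torch.rand(4, device="cuda")
    torch.cuda.set_rng_state(saved)      # restore the snapshot
    b = torch.rand(4, device="cuda")
    assert torch.equal(a, b)             # same state -> same draw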