Avoid unnecessary clone in torch.cuda.set_rng_state (#149283)

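Cloning the state tensor has a performance cost, as noted in megatron/core/tensor_parallel/random.py (lines 77-80, at commit f49c3eb6e6). The clone is only needed when CUDA has not been initialized yet, because in that case the callback that applies the state runs later, during lazy initialization; when CUDA is already initialized, the state is now used directly without cloning.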
Pull Request resolved: https://github.com/pytorch/pytorch/pull/149283
Approved by: https://github.com/cyyever, https://github.com/Skylion007
This commit is contained in:
Yuxin Wu 2025-03-18 20:47:57 +00:00 committed by PyTorch MergeBot
parent cd5c13d8f0
commit d80a70b58a

View File

@ -5,7 +5,7 @@ from typing import Union
import torch
from torch import Tensor
from . import _lazy_call, _lazy_init, current_device, device_count
from . import _lazy_call, _lazy_init, current_device, device_count, is_initialized
__all__ = [
@ -59,8 +59,11 @@ def set_rng_state(
device (torch.device or int, optional): The device to set the RNG state.
Default: ``'cuda'`` (i.e., ``torch.device('cuda')``, the current CUDA device).
"""
with torch._C._DisableFuncTorch():
new_state_copy = new_state.clone(memory_format=torch.contiguous_format)
if not is_initialized():
with torch._C._DisableFuncTorch():
# Clone the state because the callback will be triggered
# later when CUDA is lazy initialized.
new_state = new_state.clone(memory_format=torch.contiguous_format)
if isinstance(device, str):
device = torch.device(device)
elif isinstance(device, int):
@ -71,7 +74,7 @@ def set_rng_state(
if idx is None:
idx = current_device()
default_generator = torch.cuda.default_generators[idx]
default_generator.set_state(new_state_copy)
default_generator.set_state(new_state)
_lazy_call(cb)