Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-06 12:20:52 +01:00)
New PR for https://github.com/pytorch/pytorch/pull/75970 to be compatible with GHF. Pull Request resolved: https://github.com/pytorch/pytorch/pull/76374 Approved by: https://github.com/awgu
32 lines · 1.1 KiB · Python
import torch.distributed as dist


def _verify_param_shape_across_processes(process_group, tensors, logger=None):
    # Confirms that every rank in ``process_group`` holds tensors of matching
    # shapes before any broadcast is attempted.
    return dist._verify_params_across_processes(process_group, tensors, logger)


def _sync_params_and_buffers(
    module,
    process_group,
    broadcast_bucket_size,
    src,
    params_and_buffers_to_ignore,
):
    """
    Syncs ``module``'s parameters and buffers so that all ranks contain the
    same module state. Note that this API assumes that all parameter shapes
    are consistent before running the synchronization. This can be checked
    with ``_verify_param_shape_across_processes``.
    """
    module_states = []
    # Collect detached parameter tensors, skipping any names the caller asked
    # to ignore.
    for name, param in module.named_parameters():
        if name not in params_and_buffers_to_ignore:
            module_states.append(param.detach())

    # Do the same for buffers (e.g. BatchNorm running statistics).
    for name, buffer in module.named_buffers():
        if name not in params_and_buffers_to_ignore:
            module_states.append(buffer.detach())

    # Broadcast everything from rank ``src`` in coalesced buckets to limit
    # the number of collective calls.
    if len(module_states) > 0:
        dist._broadcast_coalesced(
            process_group, module_states, broadcast_bucket_size, src
        )
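Below is a minimal usage sketch, not part of the file above, showing how these two helpers might be called together when constructing a distributed wrapper (roughly the order DistributedDataParallel follows). The function name, the process-group creation via dist.new_group, the 250 MiB bucket size, and the empty ignore set are assumptions made for illustration, not values taken from this file.

import torch.distributed as dist
import torch.nn as nn


def sync_module_across_ranks(module: nn.Module):
    # Assumes ``dist.init_process_group(...)`` has already been called.
    world = dist.new_group(ranks=list(range(dist.get_world_size())))

    # 1. Fail fast if any rank holds parameters of a different shape.
    _verify_param_shape_across_processes(world, list(module.parameters()))

    # 2. Broadcast rank 0's parameters and buffers to every other rank.
    _sync_params_and_buffers(
        module,
        process_group=world,
        broadcast_bucket_size=250 * 1024 * 1024,  # assumed 250 MiB bucket
        src=0,
        params_and_buffers_to_ignore=set(),
    )

After this call every rank holds rank 0's copy of the model, which is the precondition DistributedDataParallel relies on before gradient synchronization begins.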