mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-08 07:39:33 +01:00
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/58003
- adds trainer class DdpTrainer
- adds trainer class DdpSparseRpcTrainer
- adds server class ParameterServerBase
- adds server class AverageParameterServer
- adds experiment ddp_cpu_sparse_rpc_nccl_allreduce
- adds experiment ddp_cuda_sparse_rpc_nccl_allreduce
quip document https://fb.quip.com/iQUtAeKIxWpF
Test Plan: Imported from OSS
Reviewed By: albanD
Differential Revision: D29379696
Pulled By: gcramer23
fbshipit-source-id: 9cf5fb7398ba2fa3eb694afbddc4ed00d97f205f
68 lines
2.0 KiB
Python
68 lines
2.0 KiB
Python
import torch
|
|
|
|
# Label handed to ``record_hook_fut_start`` when the gradient bucket is sent
# to the server in sparse (indices/values/size list) form.
RPC_SPARSE = "rpc_sparse"
|
|
# Label handed to ``record_hook_fut_start`` when the gradient bucket is sent
# to the server as a regular (non-sparse) tensor.
RPC_DENSE = "rpc_dense"
|
|
|
|
|
|
def sparse_tensor_to_rpc_format(sparse_tensor):
    r"""
    Break a sparse COO tensor into a plain list of its components.

    The tensor is coalesced first, so the returned indices and values
    describe the coalesced form of the input.

    Args:
        sparse_tensor (torch.Tensor): sparse COO tensor to convert

    Returns:
        list: ``[indices, values, size]`` of the coalesced tensor
    """
    coalesced = sparse_tensor.coalesce()
    indices = coalesced.indices()
    values = coalesced.values()
    shape = coalesced.size()
    return [indices, values, shape]
|
|
|
|
|
|
def sparse_rpc_format_to_tensor(sparse_rpc_format):
    r"""
    Rebuild a coalesced sparse COO tensor from its list representation.

    Args:
        sparse_rpc_format (list): sparse tensor components, read by position
            as ``[indices, values, size]``

    Returns:
        torch.Tensor: coalesced sparse COO tensor built from the components
    """
    indices = sparse_rpc_format[0]
    values = sparse_rpc_format[1]
    shape = sparse_rpc_format[2]
    rebuilt = torch.sparse_coo_tensor(indices, values, shape)
    return rebuilt.coalesce()
|
|
|
|
|
|
def process_bucket_with_remote_server(state, bucket):
    r"""
    Send a DDP gradient bucket to the remote server for averaging.

    Meant to be called from a DDP communication hook during .backward().
    Both sparse and dense bucket tensors are handled; a sparse tensor is
    converted to its list form before it is sent. The trainer's hook
    future start/end metric is recorded around the RPC.

    Args:
        state (object): training state; provides ``cref``, ``batch_number``
            and ``get_key``
        bucket (GradBucket): gradient bucket

    Returns:
        A future, chained onto the RPC future, whose callback returns a
        one-element list holding the server's result as a tensor moved to
        CUDA device ``cref.rank``.
    """
    trainer = state.cref

    grad = bucket.get_tensor()
    # Without CUDA RPC the payload is moved to the CPU before it is sent.
    if not trainer.use_cuda_rpc:
        grad = grad.cpu()

    is_sparse = grad.is_sparse
    if is_sparse:
        grad = sparse_tensor_to_rpc_format(grad)

    bucket_index = bucket.get_index()
    rpc_args = (
        trainer.server_rref,
        state.batch_number,
        bucket_index,
        grad,
    )

    key = state.get_key(bucket_index)
    if is_sparse:
        comm_mode = RPC_SPARSE
    else:
        comm_mode = RPC_DENSE
    # The metric is started right before the RPC is dispatched.
    trainer.record_hook_fut_start(key, comm_mode)
    rpc_future = trainer.server_rref.rpc_async().average_gradient(*rpc_args)

    def _on_averaged(completed):
        # Stop the metric first, then unpack the server's reply.
        trainer.record_hook_fut_end(key)
        averaged = completed.wait()
        # A list reply is the [indices, values, size] form of a sparse tensor.
        if type(averaged) is list:
            averaged = sparse_rpc_format_to_tensor(averaged)
        return [averaged.cuda(trainer.rank)]

    return rpc_future.then(_on_averaged)
|