pytorch/benchmarks/distributed/rpc/parameter_server/utils.py
Garrett Cramer 4ed2d5d9bb ps sparse rpc (#58003)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/58003

adds trainer class DdpTrainer
adds trainer class DdpSparseRpcTrainer
adds server class ParameterServerBase
adds server class AverageParameterServer
adds experiment ddp_cpu_sparse_rpc_nccl_allreduce
adds experiment ddp_cuda_sparse_rpc_nccl_allreduce

quip document https://fb.quip.com/iQUtAeKIxWpF

Test Plan: Imported from OSS

Reviewed By: albanD

Differential Revision: D29379696

Pulled By: gcramer23

fbshipit-source-id: 9cf5fb7398ba2fa3eb694afbddc4ed00d97f205f
2021-06-24 17:21:49 -07:00

68 lines
2.0 KiB
Python

import torch
# Metric-key labels used by the trainer to tag RPC future timing records,
# distinguishing sparse-gradient buckets from dense-gradient buckets.
RPC_SPARSE = "rpc_sparse"
RPC_DENSE = "rpc_dense"
def sparse_tensor_to_rpc_format(sparse_tensor):
    r"""
    A helper method that deconstructs a sparse tensor into a list of its
    indices, values, and size so it can be sent over RPC.
    The inverse of ``sparse_rpc_format_to_tensor``.
    Args:
        sparse_tensor (torch.Tensor): a sparse_coo_tensor
    Returns:
        list: ``[indices, values, size]`` of the coalesced input tensor
    """
    # coalesce() sums duplicate entries; indices()/values() require a
    # coalesced tensor.
    sparse_tensor = sparse_tensor.coalesce()
    return [sparse_tensor.indices(), sparse_tensor.values(), sparse_tensor.size()]
def sparse_rpc_format_to_tensor(sparse_rpc_format):
    r"""
    A helper method creates a sparse_coo_tensor from indices, values, and size.
    Args:
        sparse_rpc_format (list): sparse_coo_tensor represented as a list
    """
    # Unpack the [indices, values, size] triple produced for RPC transport.
    indices, values, size = sparse_rpc_format
    tensor = torch.sparse_coo_tensor(indices, values, size)
    return tensor.coalesce()
def process_bucket_with_remote_server(state, bucket):
    r"""
    Processes a gradient bucket passed by a DDP communication hook
    during .backward(). The method supports processing sparse and dense
    tensors. It records RPC future completion time metric for the trainer.
    Args:
        state (object): maintains state during the training process
        bucket (GradBucket): gradient bucket
    Returns:
        torch.futures.Future: future that resolves to a one-element list
        containing the averaged gradient tensor, as DDP hooks require.
    """
    # cref is the trainer object that owns the server rref and metric hooks.
    cref = state.cref
    tensor = bucket.get_tensor()
    # When CUDA-over-RPC is disabled, move the gradient to CPU before
    # sending; the callback moves the result back to this rank's device.
    if not cref.use_cuda_rpc:
        tensor = tensor.cpu()
    sparse = tensor.is_sparse
    if sparse:
        # Sparse tensors are sent as an [indices, values, size] list —
        # presumably because sparse_coo_tensor itself is not directly
        # RPC-serializable here; see sparse_tensor_to_rpc_format.
        tensor = sparse_tensor_to_rpc_format(tensor)
    b_index = bucket.get_index()
    server_args = [
        cref.server_rref,
        state.batch_number,
        b_index,
        tensor
    ]
    # Key identifies this (batch, bucket) pair for metric recording.
    key = state.get_key(b_index)
    cref.record_hook_fut_start(
        key,
        RPC_SPARSE if sparse else RPC_DENSE
    )
    # Asynchronously ask the parameter server to average this bucket's
    # gradient across trainers.
    fut = cref.server_rref.rpc_async().average_gradient(*server_args)
    def callback(fut):
        # Record completion time for the RPC future metric.
        cref.record_hook_fut_end(key)
        tensor = fut.wait()
        # A list result means the server returned the sparse RPC format;
        # rebuild the sparse tensor from it.
        if type(tensor) is list:
            tensor = sparse_rpc_format_to_tensor(tensor)
        # Move the averaged gradient onto this trainer's GPU.
        tensor = tensor.cuda(cref.rank)
        # DDP communication hooks expect the future to resolve to a
        # single-element list containing the result tensor.
        return [tensor]
    return fut.then(callback)