import torch

RPC_SPARSE = "rpc_sparse"
RPC_DENSE = "rpc_dense"


def sparse_tensor_to_rpc_format(sparse_tensor):
    r"""
    A helper function that creates a list containing the indices, values,
    and size of a coalesced sparse tensor.
    Args:
        sparse_tensor (torch.Tensor): the sparse_coo_tensor to convert
    """
    sparse_tensor = sparse_tensor.coalesce()
    return [sparse_tensor.indices(), sparse_tensor.values(), sparse_tensor.size()]


def sparse_rpc_format_to_tensor(sparse_rpc_format):
    r"""
    A helper function that creates a sparse_coo_tensor from its indices,
    values, and size.
    Args:
        sparse_rpc_format (list): a sparse_coo_tensor represented as a list
            of [indices, values, size]
    """
    return torch.sparse_coo_tensor(
        sparse_rpc_format[0], sparse_rpc_format[1], sparse_rpc_format[2]
    ).coalesce()


def process_bucket_with_remote_server(state, bucket):
    r"""
    Processes a gradient bucket passed to a DDP communication hook during
    .backward(). Supports both sparse and dense tensors and records the
    RPC future completion time metric for the trainer.
    Args:
        state (object): maintains state during the training process
        bucket (GradBucket): gradient bucket
    """
    cref = state.cref
    tensor = bucket.buffer()
    if not cref.use_cuda_rpc:
        tensor = tensor.cpu()
    sparse = tensor.is_sparse
    if sparse:
        # Convert the sparse tensor to a [indices, values, size] list so it
        # can be sent over RPC.
        tensor = sparse_tensor_to_rpc_format(tensor)
    b_index = bucket.get_index()
    server_args = [cref.server_rref, state.batch_number, b_index, tensor]
    key = state.get_key(b_index)
    cref.record_start("hook_future_metric", key, RPC_SPARSE if sparse else RPC_DENSE)
    fut = cref.server_rref.rpc_async().average_gradient(*server_args)

    def callback(fut):
        # Stop the timing metric once the server has returned the averaged gradient.
        cref.record_end("hook_future_metric", key)
        tensor = fut.wait()
        if isinstance(tensor, list):
            # The server returned a sparse gradient in list format; rebuild it.
            tensor = sparse_rpc_format_to_tensor(tensor)
        tensor = tensor.cuda(cref.rank)
        return [tensor]

    return fut.then(callback)
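

# A minimal usage sketch (not part of the benchmark itself): it round-trips a
# sparse gradient through the two helpers above to show the [indices, values,
# size] list format used for RPC. The tensor values here are illustrative.
# The hook itself is meant to be registered with
# DistributedDataParallel.register_comm_hook, together with a state object
# providing cref, batch_number, and get_key (supplied elsewhere in the
# benchmark).
if __name__ == "__main__":
    indices = torch.tensor([[0, 1, 1], [2, 0, 2]])
    values = torch.tensor([3.0, 4.0, 5.0])
    grad = torch.sparse_coo_tensor(indices, values, (2, 3))

    # Flatten for RPC, then rebuild on the "receiving" side.
    payload = sparse_tensor_to_rpc_format(grad)
    rebuilt = sparse_rpc_format_to_tensor(payload)
    assert torch.equal(rebuilt.to_dense(), grad.to_dense())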