Add set_device_map to TensorPipeOptions to support GPU args (#42637)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/42637

This commit enables sending non-CPU tensors through RPC using the
TensorPipe backend. Users can configure device mappings by calling
`set_device_map` on `TensorPipeRpcBackendOptions`. Internally,
the `init_rpc` API verifies the correctness of device mappings; it
shuts down RPC if the check fails, or proceeds and passes the global
mappings to `TensorPipeAgent` if the check succeeds. For serde,
we added a device-indices field to the TensorPipe read and write
buffers, which must either be empty (all tensors must be on CPU) or
match the tensors in the RPC message in order and number. This commit
is not yet zero-copy: a tensor is always moved to CPU on the sender
and then moved to the specified device on the receiver.
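
For illustration, a minimal usage sketch of the new option (worker names,
ranks, and device counts are hypothetical):

import torch
import torch.distributed.rpc as rpc

options = rpc.TensorPipeRpcBackendOptions(num_worker_threads=8)
# Hypothetical pair of workers, each with at least two GPUs:
# map this worker's cuda:0 to worker1's cuda:1.
options.set_device_map("worker1", {0: 1})

rpc.init_rpc(
    "worker0",
    rank=0,
    world_size=2,
    backend=rpc.BackendType.TENSORPIPE,
    rpc_backend_options=options,
)

# The GPU argument is staged through CPU on the sender and recreated
# on worker1's cuda:1 before the user function runs.
ret = rpc.rpc_sync("worker1", torch.add, args=(torch.ones(2).cuda(0), 1))
rpc.shutdown()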

Test Plan: Imported from OSS

Reviewed By: izdeby

Differential Revision: D23011572

Pulled By: mrshenli

fbshipit-source-id: 62b617eed91237d4e9926bc8551db78b822a1187
Shen Li 2020-08-14 18:44:41 -07:00 committed by Facebook GitHub Bot
parent c84f78470b
commit 06aaf8c20d
11 changed files with 798 additions and 152 deletions


@ -58,7 +58,11 @@ TEST(TensorpipeSerialize, Base) {
for (int i = 0; i < recvingTpMessage.payloads.size(); i++) {
tensorpipe::Message::Payload& srcPayload = sendingTpMessage.payloads[i];
tensorpipe::Message::Payload& dstPayload = recvingTpMessage.payloads[i];
memcpy(dstPayload.data, srcPayload.data, srcPayload.length);
if (srcPayload.length) {
// Empty vector's data() can return nullptr, use the length to avoid
// copying into nullptr
memcpy(dstPayload.data, srcPayload.data, srcPayload.length);
}
}
for (int i = 0; i < recvingTpMessage.tensors.size(); i++) {
tensorpipe::Message::Tensor& srcTensor = sendingTpMessage.tensors[i];


@ -459,42 +459,22 @@ PyObject* rpc_init(PyObject* /* unused */) {
// Base class: torch.distributed.rpc.RpcBackendOptions.
py::class_<TensorPipeRpcBackendOptions>(
module,
"TensorPipeRpcBackendOptions",
rpcBackendOptions,
R"(
The backend options for
:class:`~torch.distributed.rpc.TensorPipeAgent`, derived from
:class:`~torch.distributed.rpc.RpcBackendOptions`.
Arguments:
num_worker_threads (int, optional): The number of threads in the
thread-pool used by
:class:`~torch.distributed.rpc.TensorPipeAgent` to execute
requests (default: 16).
rpc_timeout (float, optional): The default timeout, in seconds,
for RPC requests (default: 60 seconds). If the RPC has not
completed in this timeframe, an exception indicating so will
be raised. Callers can override this timeout for individual
RPCs in :meth:`~torch.distributed.rpc.rpc_sync` and
:meth:`~torch.distributed.rpc.rpc_async` if necessary.
init_method (str, optional): The URL to initialize the distributed
store used for rendezvous. It takes any value accepted for the
same argument of :meth:`~torch.distributed.init_process_group`
(default: ``env://``).
)")
module, "_TensorPipeRpcBackendOptionsBase", rpcBackendOptions)
.def(
py::init<
int,
optional<std::vector<std::string>>,
optional<std::vector<std::string>>,
float,
std::string>(),
std::string,
std::unordered_map<std::string, tensorpipe::DeviceMap>>(),
py::arg("num_worker_threads") = kDefaultNumWorkerThreads,
py::arg("_transports") = optional<std::vector<std::string>>(),
py::arg("_channels") = optional<std::vector<std::string>>(),
py::arg("rpc_timeout") = kDefaultRpcTimeoutSeconds,
py::arg("init_method") = kDefaultInitMethod)
py::arg("init_method") = kDefaultInitMethod,
py::arg("device_maps") =
std::unordered_map<std::string, tensorpipe::DeviceMap>())
.def_readwrite(
"num_worker_threads",
&TensorPipeRpcBackendOptions::numWorkerThreads,
@ -502,7 +482,12 @@ PyObject* rpc_init(PyObject* /* unused */) {
The number of threads in the thread-pool used by
:class:`~torch.distributed.rpc.TensorPipeAgent` to execute
requests.
)");
)")
.def_readwrite(
"device_maps",
&TensorPipeRpcBackendOptions::deviceMaps,
R"(The device map locations.)")
.def("set_device_map", &TensorPipeRpcBackendOptions::setDeviceMap);
module.attr("_DEFAULT_NUM_WORKER_THREADS") =
py::cast(kDefaultNumWorkerThreads);
@ -552,7 +537,11 @@ PyObject* rpc_init(PyObject* /* unused */) {
"get_worker_infos",
(std::vector<WorkerInfo>(TensorPipeAgent::*)() const) &
TensorPipeAgent::getWorkerInfos,
py::call_guard<py::gil_scoped_release>());
py::call_guard<py::gil_scoped_release>())
.def(
"_set_reverse_device_maps",
// intentionally not releasing GIL to avoid unnecessary context switch
&TensorPipeAgent::setReverseDeviceMaps);
#endif // USE_TENSORPIPE
@ -763,9 +752,7 @@ PyObject* rpc_init(PyObject* /* unused */) {
application's responsibility to make sure that the above assumption
always holds.
)");
module.def(
"_disable_jit_rref_pickle",
&disableJitRRefPickle);
module.def("_disable_jit_rref_pickle", &disableJitRRefPickle);
Py_RETURN_TRUE;
}


@ -32,6 +32,44 @@ const std::string kServerActiveAsyncCalls = "agent.server_active_async_calls";
const std::string kRpcTimeoutErrorStr =
"RPC ran for more than set timeout ({} ms) and will now be marked with an error";
inline void checkCPUTensor(const torch::Tensor& tensor) {
TORCH_CHECK(
tensor.device() == at::kCPU,
"TensorPipeAgent only supports CPU tensors by default. Sending "
"GPU tensors using RPC requires explicitly configurations using "
"`set_device_map` on `TensorPipeRpcBackendOptions`. Got a tensor "
"with device ",
tensor.device(),
", but no device map is specified.");
}
std::vector<c10::DeviceIndex> getDevicesForTensors(
const std::string& remoteName,
const std::vector<torch::Tensor>& tensors,
const std::unordered_map<std::string, tensorpipe::DeviceMap>& deviceMaps) {
const auto workerIter = deviceMaps.find(remoteName);
if (workerIter == deviceMaps.end()) {
for (const auto& tensor : tensors) {
checkCPUTensor(tensor);
}
return {};
} else {
std::vector<c10::DeviceIndex> deviceIndices;
deviceIndices.reserve(tensors.size());
const auto& deviceMap = workerIter->second;
for (const auto& tensor : tensors) {
const auto deviceIter = deviceMap.find(tensor.device().index());
if (deviceIter == deviceMap.end()) {
checkCPUTensor(tensor);
deviceIndices.push_back(-1);
} else {
deviceIndices.push_back(deviceIter->second);
}
}
return deviceIndices;
}
}
} // namespace
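
The lookup above behaves like this Python sketch (an illustrative
analogue, not the actual implementation; -1 stands for CPU):

def get_devices_for_tensors(remote_name, tensors, device_maps):
    device_map = device_maps.get(remote_name)
    if device_map is None:
        # No map configured for this peer: every tensor must be on CPU.
        for t in tensors:
            assert t.device.type == "cpu"
        return []
    indices = []
    for t in tensors:
        src = t.device.index if t.device.index is not None else -1
        if src in device_map:
            indices.append(device_map[src])
        else:
            # Unmapped devices are only legal for CPU tensors.
            assert t.device.type == "cpu"
            indices.append(-1)
    return indices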
C10_DEFINE_REGISTRY(TensorPipeTransportRegistry, TransportRegistration);
@ -407,7 +445,14 @@ void TensorPipeAgent::pipeWrite(
std::function<void(const tensorpipe::Error&)> fn) {
tensorpipe::Message tpMessage;
TensorpipeWriteBuffers tpBuffers;
std::tie(tpMessage, tpBuffers) = tensorpipeSerialize(std::move(rpcMessage));
const auto& deviceMaps =
rpcMessage.isRequest() ? opts_.deviceMaps : reverseDeviceMaps_;
auto devices = getDevicesForTensors(
pipe->getRemoteName(), rpcMessage.tensors(), deviceMaps);
std::tie(tpMessage, tpBuffers) = tensorpipeSerialize(
std::move(rpcMessage), std::move(devices));
pipe->write(
std::move(tpMessage),
[tpBuffers{
@ -438,16 +483,41 @@ void TensorPipeAgent::sendCompletedResponseMessage(
Message&& responseMessage = std::move(*futureResponseMessage).moveValue();
responseMessage.setId(messageId);
if (!error) {
for (const auto& tensor : responseMessage.tensors()) {
if (!tensor.device().is_cpu()) {
responseMessage = createExceptionResponse(
c10::str(
"TensorPipe RPC backend only supports CPU tensors, please ",
"move your tensors to CPU before sending them over RPC. Found ",
"tensor on device: ",
tensor.device()),
responseMessage.id());
break;
const auto& iter = reverseDeviceMaps_.find(pipe->getRemoteName());
if (iter == opts_.deviceMaps.end()) {
for (const auto& t : responseMessage.tensors()) {
if (!t.device().is_cpu()) {
responseMessage = createExceptionResponse(
c10::str(
"TensorPipe RPC backend only supports CPU tensors by default,"
" please move your tensors to CPU before sending them over "
"RPC, or call `set_device_map` on "
"`TensorPipeRpcBackendOptions` to explicitly configure "
"device mapping. Response device mapping is not available for "
"destination ",
pipe->getRemoteName(),
", but found tensor on device: ",
t.device()),
responseMessage.id());
break;
}
}
} else {
const auto& deviceMap = iter->second;
for (const auto& t : responseMessage.tensors()) {
if (!t.device().is_cpu() &&
deviceMap.find(t.device().index()) == deviceMap.end()) {
responseMessage = createExceptionResponse(
c10::str(
"TensorPipe RPC backend only supports CPU tensors by default."
" Response device mapping is not available for destination ",
pipe->getRemoteName(),
" for device ",
t.device(),
" but received a tensor on that device."),
responseMessage.id());
break;
}
}
}
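
In other words, a non-CPU tensor in a response is only legal when the
reverse map for that peer covers its device; roughly, in Python terms
(a sketch, not the actual code):

def check_response_tensors(tensors, reverse_device_map):
    for t in tensors:
        if t.device.type == "cpu":
            continue
        if reverse_device_map is None or t.device.index not in reverse_device_map:
            raise RuntimeError(
                f"Response device mapping is not available for {t.device}"
            )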
@ -581,14 +651,6 @@ std::shared_ptr<FutureMessage> TensorPipeAgent::send(
throw std::runtime_error(err);
}
for (const auto& tensor : requestMessage.tensors()) {
TORCH_CHECK(
tensor.device().is_cpu(),
"TensorPipe RPC backend only supports CPU tensors, please move your ",
"tensors to CPU before sending them over RPC. Found tensor on device: ",
tensor.device());
}
const auto& url = findWorkerURL(toWorkerInfo);
std::unique_lock<std::mutex> lock(mutex_);


@ -33,9 +33,11 @@ namespace channel {
class Context;
} // namespace channel
using DeviceMap = std::unordered_map<c10::DeviceIndex, c10::DeviceIndex>;
} // namespace tensorpipe
namespace torch {
namespace distributed {
namespace rpc {
@ -65,11 +67,13 @@ struct TensorPipeRpcBackendOptions : public RpcBackendOptions {
optional<std::vector<std::string>> transports,
optional<std::vector<std::string>> channels,
float rpc_timeout,
std::string init_method)
std::string init_method,
std::unordered_map<std::string, tensorpipe::DeviceMap> device_maps = {})
: RpcBackendOptions(rpc_timeout, init_method),
numWorkerThreads(numWorkerThreads),
transports(std::move(transports)),
channels(std::move(channels)) {
channels(std::move(channels)),
deviceMaps(std::move(device_maps)) {
TORCH_CHECK(
numWorkerThreads > 0,
"num_worker_threads must be positive, got ",
@ -94,9 +98,23 @@ struct TensorPipeRpcBackendOptions : public RpcBackendOptions {
}
}
void setDeviceMap(
const std::string& workerName,
const tensorpipe::DeviceMap& deviceMap) {
auto iter = deviceMaps.find(workerName);
if (iter == deviceMaps.end()) {
deviceMaps[workerName] = deviceMap;
} else {
for (auto& entry : deviceMap) {
iter->second[entry.first] = entry.second;
}
}
}
int numWorkerThreads;
const optional<std::vector<std::string>> transports;
const optional<std::vector<std::string>> channels;
std::unordered_map<std::string, tensorpipe::DeviceMap> deviceMaps;
};
// Struct to track the network source metrics
@ -148,6 +166,11 @@ class TensorPipeAgent : public RpcAgent {
const WorkerInfo& getWorkerInfo(const std::string& workerName) const override;
const WorkerInfo& getWorkerInfo(worker_id_t workerId) const override;
std::vector<WorkerInfo> getWorkerInfos() const override;
void setReverseDeviceMaps(
const std::unordered_map<std::string, tensorpipe::DeviceMap>&
reverseDeviceMaps) {
reverseDeviceMaps_ = reverseDeviceMaps;
}
std::unordered_map<std::string, std::string> getMetrics() override;
@ -233,6 +256,7 @@ class TensorPipeAgent : public RpcAgent {
};
const TensorPipeRpcBackendOptions opts_;
std::unordered_map<std::string, tensorpipe::DeviceMap> reverseDeviceMaps_;
ThreadPool threadPool_;
std::shared_ptr<tensorpipe::Context> context_;


@ -55,6 +55,15 @@ void processRemoteProfiledEvents(
// Add event list to the thread local profiler.
torch::autograd::profiler::addEventList(std::move(events));
}
inline c10::Device indexToDevice(c10::DeviceIndex index) {
if (index == -1) {
return c10::Device(at::kCPU);
} else {
return c10::Device(at::kCUDA, index);
}
}
} // namespace
const std::string kRPCErrorPrefix = std::string("RPCErr");
@ -465,22 +474,27 @@ constexpr int kTpMessageTypeIdx = 0;
constexpr int kTpMessageIdIdx = 1;
// Then comes the rpc::Message::payload();
constexpr int kTpMessagePayloadIdx = 2;
// Then comes the destination device indices for all tensors in the message.
constexpr int kTpMessageDevicesIdx = 3;
// Last comes the pickle of rpc::Message::tensors() (with the tensors themselves
// stored as, well, tensors in the tensorpipe::Message).
constexpr int kTpMessagePickleIdx = 3;
constexpr int kTpMessagePickleIdx = 4;
} // namespace
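
So a serialized message now carries five payloads; as a quick reference
(mirroring the constants above):

# Wire layout of tensorpipe::Message::payloads after this change:
WIRE_PAYLOADS = [
    "message type",    # kTpMessageTypeIdx = 0
    "message id",      # kTpMessageIdIdx = 1
    "rpc payload",     # kTpMessagePayloadIdx = 2
    "device indices",  # kTpMessageDevicesIdx = 3 (new; empty => all-CPU)
    "tensor pickle",   # kTpMessagePickleIdx = 4
]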
std::tuple<tensorpipe::Message, TensorpipeWriteBuffers> tensorpipeSerialize(
Message&& rpcMessage) {
Message&& rpcMessage,
std::vector<c10::DeviceIndex> deviceIndices) {
tensorpipe::Message tpMessage;
TensorpipeWriteBuffers buffers;
// Metadata
buffers.type = std::make_unique<MessageType>(rpcMessage.type());
buffers.id = std::make_unique<int64_t>(rpcMessage.id());
// kTpMessageTypeIdx = 0
tpMessage.payloads.push_back(
tensorpipe::Message::Payload{buffers.type.get(), sizeof(MessageType)});
// kTpMessageIdIdx = 1
tpMessage.payloads.push_back(
tensorpipe::Message::Payload{buffers.id.get(), sizeof(int64_t)});
@ -490,11 +504,30 @@ std::tuple<tensorpipe::Message, TensorpipeWriteBuffers> tensorpipeSerialize(
// it uses non-const pointers even though it doesn't modify them when writing.
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
char* payloadPtr = const_cast<char*>(buffers.payload.data());
// kTpMessagePayloadIdx = 2
tpMessage.payloads.push_back(
tensorpipe::Message::Payload{payloadPtr, buffers.payload.size()});
// Device indices
buffers.deviceIndices = std::move(deviceIndices);
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
auto indicesPtr = const_cast<c10::DeviceIndex*>(buffers.deviceIndices.data());
auto size = buffers.deviceIndices.size() * sizeof(c10::DeviceIndex);
// kTpMessageDevicesIdx = 3
tpMessage.payloads.push_back(tensorpipe::Message::Payload{indicesPtr, size});
// Tensors
buffers.tensors = cloneSparseTensors(rpcMessage.tensors()).vec();
if (buffers.deviceIndices.empty()) {
buffers.tensors = cloneSparseTensors(rpcMessage.tensors()).vec();
} else {
std::vector<torch::Tensor> tensors;
tensors.reserve(rpcMessage.tensors().size());
for (const auto& tensor : rpcMessage.tensors()) {
tensors.emplace_back(tensor.cpu());
}
buffers.tensors = cloneSparseTensors(tensors).vec();
}
torch::jit::Pickler pickler([&](const void* buf, size_t sz) -> size_t {
buffers.pickle.insert(
buffers.pickle.end(),
@ -505,6 +538,7 @@ std::tuple<tensorpipe::Message, TensorpipeWriteBuffers> tensorpipeSerialize(
pickler.protocol();
pickler.pushIValue(buffers.tensors);
pickler.stop();
// kTpMessagePickleIdx = 4
tpMessage.payloads.push_back(tensorpipe::Message::Payload{
buffers.pickle.data(), buffers.pickle.size()});
for (const auto& tensor : pickler.tensorData()) {
@ -534,8 +568,8 @@ TensorpipeReadBuffers tensorpipeAllocate(tensorpipe::Message& tpMessage) {
TensorpipeReadBuffers buffers;
TORCH_INTERNAL_ASSERT(
tpMessage.payloads.size() == 4,
"message expected to contain 4 payloads, whereas it contained ",
tpMessage.payloads.size() == 5,
"message expected to contain 5 payloads, whereas it contained ",
tpMessage.payloads.size(),
" payloads");
@ -564,6 +598,17 @@ TensorpipeReadBuffers tensorpipeAllocate(tensorpipe::Message& tpMessage) {
buffers.payload.resize(tpMessage.payloads[kTpMessagePayloadIdx].length);
tpMessage.payloads[kTpMessagePayloadIdx].data = buffers.payload.data();
auto deviceIndexSize = sizeof(c10::DeviceIndex);
TORCH_INTERNAL_ASSERT(
tpMessage.payloads[kTpMessageDevicesIdx].length % deviceIndexSize == 0,
"Invalid number of bytes for device index payload, total bytes: ",
tpMessage.payloads[kTpMessageDevicesIdx].length,
", device index size: ",
deviceIndexSize);
buffers.deviceIndices.resize(
tpMessage.payloads[kTpMessageDevicesIdx].length / deviceIndexSize);
tpMessage.payloads[kTpMessageDevicesIdx].data = buffers.deviceIndices.data();
buffers.pickle.resize(tpMessage.payloads[kTpMessagePickleIdx].length);
tpMessage.payloads[kTpMessagePickleIdx].data = buffers.pickle.data();
@ -606,6 +651,23 @@ Message tensorpipeDeserialize(
tensors.emplace_back(std::move(t));
}
if (!buffers.deviceIndices.empty()) {
TORCH_INTERNAL_ASSERT(
buffers.deviceIndices.size() == tensors.size(),
"Number of device indices must match the number of tensors in the "
"RPC message. But got ",
tensors.size(),
" tensors with ",
buffers.deviceIndices.size(),
" device indices.");
for (size_t i = 0; i < tensors.size(); ++i) {
auto index = buffers.deviceIndices[i];
if (tensors[i].device().index() != index) {
tensors[i] = tensors[i].to(indexToDevice(index));
}
}
}
return Message(
std::move(buffers.payload),
std::move(tensors),


@ -1,5 +1,6 @@
#pragma once
#include <c10/core/Device.h>
#include <torch/csrc/distributed/rpc/rpc_command_base.h>
#include <torch/csrc/jit/serialization/pickle.h>
#include <torch/csrc/utils/byte_order.h>
@ -63,6 +64,7 @@ struct TensorpipeWriteBuffers {
std::unique_ptr<MessageType> type;
std::unique_ptr<int64_t> id;
std::vector<char> payload;
std::vector<c10::DeviceIndex> deviceIndices;
std::vector<char> pickle;
// This contains the original tensors and the clones of the sparse tensors.
std::vector<torch::Tensor> tensors;
@ -78,6 +80,7 @@ struct TensorpipeReadBuffers {
std::unique_ptr<MessageType> type;
std::unique_ptr<int64_t> id;
std::vector<char> payload;
std::vector<c10::DeviceIndex> deviceIndices;
std::vector<char> pickle;
std::vector<c10::DataPtr> tensors;
};
@ -85,7 +88,9 @@ struct TensorpipeReadBuffers {
// Convert an RPC message into a TensorPipe message, plus a holder to all the
// data that must be kept alive while the write is performed asynchronously.
TORCH_API std::tuple<tensorpipe::Message, TensorpipeWriteBuffers>
tensorpipeSerialize(Message&& rpcMessage);
tensorpipeSerialize(
Message&& rpcMessage,
std::vector<c10::DeviceIndex> devices = {});
// Allocate the buffers that will hold the incoming data. They will be managed
// by the returned holder, which must be kept alive until the asynchronous read


@ -1,7 +1,5 @@
from __future__ import absolute_import, division, print_function, unicode_literals
import numbers
import torch
import torch.distributed as dist
import threading
@ -17,19 +15,26 @@ if is_available() and not torch._C._rpc_init():
raise RuntimeError("Failed to initialize torch.distributed.rpc")
if is_available():
from . import api, backend_registry, functions, _set_profiler_node_id
from . import (
_enable_jit_rref_pickle,
_disable_jit_rref_pickle,
_enable_jit_rref_pickle,
_set_and_start_rpc_agent,
) # noqa: F401
from .api import * # noqa: F401
from .options import TensorPipeRpcBackendOptions # noqa: F401
from .backend_registry import BackendType
from .server_process_global_profiler import (
_server_process_global_profile,
)
import torch.distributed.autograd as dist_autograd
import numbers
def init_rpc(
name,
backend=BackendType.PROCESS_GROUP,
@ -99,7 +104,52 @@ if is_available():
_set_profiler_node_id(rank)
# Initialize RPC.
api._init_rpc_backend(backend, store, name, rank, world_size, rpc_backend_options)
_init_rpc_backend(backend, store, name, rank, world_size, rpc_backend_options)
def _validate_rpc_args(backend, store, name, rank, world_size, rpc_backend_options):
type_mapping = {
backend: backend_registry.BackendType,
store: dist.Store,
name: str,
rank: numbers.Integral,
world_size: numbers.Integral,
rpc_backend_options: RpcBackendOptions,
}
for arg, arg_type in type_mapping.items():
if not isinstance(arg, arg_type):
raise RuntimeError(
"Argument {} must be of type {} but got type {}".format(
arg, arg_type, type(arg)
)
)
def _init_rpc_backend(
backend=backend_registry.BackendType.PROCESS_GROUP,
store=None,
name=None,
rank=-1,
world_size=-1,
rpc_backend_options=None,
):
_validate_rpc_args(backend, store, name, rank, world_size, rpc_backend_options)
if _is_current_rpc_agent_set():
raise RuntimeError("RPC is already initialized")
# Initialize RPC.
rpc_agent = backend_registry.init_backend(
backend,
store=store,
name=name,
rank=rank,
world_size=world_size,
rpc_backend_options=rpc_backend_options,
)
api._init_rpc_states(rpc_agent)
@api._require_initialized


@ -3,17 +3,14 @@ import contextlib
import functools
import inspect
import logging
import numbers
import threading
from typing import Generic, TypeVar
import torch
import torch.distributed as dist
from . import (
PyRRef,
RemoteProfilerManager,
RpcBackendOptions,
WorkerInfo,
_cleanup_python_rpc_handler,
_delete_all_user_and_unforked_owner_rrefs,
@ -28,7 +25,6 @@ from . import (
_is_current_rpc_agent_set,
_reset_current_rpc_agent,
_set_and_start_rpc_agent,
backend_registry,
)
from .internal import (
@ -108,6 +104,16 @@ _all_gather_sequence_id = 0
_all_gather_sequence_id_to_states = collections.defaultdict(AllGatherStates)
def _init_rpc_states(agent):
worker_infos = agent.get_worker_infos()
global _ALL_WORKER_NAMES
_ALL_WORKER_NAMES = {worker_info.name for worker_info in worker_infos}
# NB: backend implementation might have already set the rpc_agent.
if not _is_current_rpc_agent_set():
_set_and_start_rpc_agent(agent)
def _gather_to_leader(sequence_id, worker_name, obj):
with _all_gather_dict_lock:
assert (
@ -287,38 +293,6 @@ def shutdown(graceful=True):
_reset_current_rpc_agent()
# TODO: add a context manager to wrap _init_rpc_backend and shutdown
def _init_rpc_backend(
backend=backend_registry.BackendType.PROCESS_GROUP,
store=None,
name=None,
rank=-1,
world_size=-1,
rpc_backend_options=None,
):
_validate_rpc_args(backend, store, name, rank, world_size, rpc_backend_options)
if _is_current_rpc_agent_set():
raise RuntimeError("RPC is already initialized")
# Initialize RPC.
rpc_agent = backend_registry.init_backend(
backend,
store=store,
name=name,
rank=rank,
world_size=world_size,
rpc_backend_options=rpc_backend_options,
)
worker_infos = rpc_agent.get_worker_infos()
global _ALL_WORKER_NAMES
_ALL_WORKER_NAMES = {worker_info.name for worker_info in worker_infos}
_set_and_start_rpc_agent(rpc_agent)
@_require_initialized
def get_worker_info(worker_name=None):
r"""
@ -350,24 +324,6 @@ def _to_worker_info(name_or_info):
raise ValueError("Cannot get WorkerInfo from name {}".format(name_or_info))
def _validate_rpc_args(backend, store, name, rank, world_size, rpc_backend_options):
type_mapping = {
backend: backend_registry.BackendType,
store: dist.Store,
name: str,
rank: numbers.Integral,
world_size: numbers.Integral,
rpc_backend_options: RpcBackendOptions,
}
for arg, arg_type in type_mapping.items():
if not isinstance(arg, arg_type):
raise RuntimeError(
"Argument {} must be of type {} but got type {}".format(
arg, arg_type, type(arg)
)
)
T = TypeVar("T")
GenericWithOneTypeVar = Generic[T]


@ -4,8 +4,10 @@ import collections
from datetime import timedelta
import enum
import torch
import torch.distributed as dist
from . import api
from . import constants as rpc_constants
@ -171,6 +173,57 @@ def _tensorpipe_construct_rpc_backend_options_handler(
)
# detect invalid device_map configurations on any worker, and raise an
# error naming the offending workers
def _tensorpipe_check_device_maps(agent, device_maps):
if device_maps is None:
device_maps = {}
def check_one_worker(name, device_maps, all_device_counts):
device_count = all_device_counts[name]
wrong_worker_names = set(device_maps) - set(all_device_counts)
if wrong_worker_names:
raise ValueError(f"Wrong worker names: {wrong_worker_names}")
for worker_name in all_device_counts:
remote_device_count = all_device_counts[worker_name]
if worker_name in device_maps:
device_map = device_maps[worker_name]
key_set = set(device_map.keys())
val_set = set(device_map.values())
if not all([
len(device_map) == len(key_set),
len(device_map) == len(val_set), # check 1-to-1 mapping
min(key_set) >= 0,
max(key_set) < device_count, # check local range
min(val_set) >= 0,
max(val_set) < remote_device_count # check remote range
]):
raise ValueError(
f"Invalid device_map configuration on {name}:\n"
f"device_maps = {device_maps}"
)
gathered = api._all_gather([torch.cuda.device_count(), device_maps])
all_device_counts = {name: gathered[name][0] for name in gathered}
all_device_maps = {name: gathered[name][1] for name in gathered}
for worker_name in all_device_maps:
worker_device_maps = all_device_maps[worker_name]
check_one_worker(worker_name, worker_device_maps, all_device_counts)
# passed all checks; construct reverse mapping for return values
reverse_device_maps = {}
local_name = api.get_worker_info().name
for worker_name in all_device_maps:
remote_device_maps = all_device_maps[worker_name]
if local_name in remote_device_maps:
remote_device_map = remote_device_maps[local_name]
reverse_device_maps[worker_name] = {
remote_device_map[k]: k for k in remote_device_map
}
agent._set_reverse_device_maps(reverse_device_maps)
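
As a concrete example of what this check accepts and rejects
(hypothetical workers with two GPUs each):

# Valid: one-to-one, indices within [0, 2) on both sides.
valid = {"worker1": {0: 1, 1: 0}}
# Invalid: many-to-one, both local devices map to remote cuda:0.
invalid = {"worker1": {0: 0, 1: 0}}
# The reverse map handed to the peer simply swaps keys and values.
reverse = {v: k for k, v in valid["worker1"].items()}  # {1: 0, 0: 1}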
def _tensorpipe_init_backend_handler(store, name, rank, world_size, rpc_backend_options):
from . import TensorPipeRpcBackendOptions
from . import TensorPipeAgent
@ -194,10 +247,20 @@ def _tensorpipe_init_backend_handler(store, name, rank, world_size, rpc_backend_
group = _init_process_group(store, rank, world_size)
# TODO: add try-except and destroy _agent in all processes if any fails.
return TensorPipeAgent(
agent = TensorPipeAgent(
store, name, rank, world_size, group, rpc_backend_options
)
api._init_rpc_states(agent)
try:
_tensorpipe_check_device_maps(agent, rpc_backend_options.device_maps)
except Exception:
api.shutdown()
raise
return agent
register_backend(
"TENSORPIPE",


@ -0,0 +1,118 @@
from . import _TensorPipeRpcBackendOptionsBase
from . import constants as rpc_constants
import torch
from typing import Dict, List
class TensorPipeRpcBackendOptions(_TensorPipeRpcBackendOptionsBase):
r"""
The backend options for
:class:`~torch.distributed.rpc.TensorPipeAgent`, derived from
:class:`~torch.distributed.rpc.RpcBackendOptions`.
Arguments:
num_worker_threads (int, optional): The number of threads in the
thread-pool used by
:class:`~torch.distributed.rpc.TensorPipeAgent` to execute
requests (default: 16).
rpc_timeout (float, optional): The default timeout, in seconds,
for RPC requests (default: 60 seconds). If the RPC has not
completed in this timeframe, an exception indicating so will
be raised. Callers can override this timeout for individual
RPCs in :meth:`~torch.distributed.rpc.rpc_sync` and
:meth:`~torch.distributed.rpc.rpc_async` if necessary.
init_method (str, optional): The URL to initialize the distributed
store used for rendezvous. It takes any value accepted for the
same argument of :meth:`~torch.distributed.init_process_group`
(default: ``env://``).
device_maps (Dict[str, Dict]): Device placement mappings from this
worker to the callee. The key is the callee worker name, and the
value is a dictionary (``Dict`` of ``int``, ``str``, or
``torch.device``) that maps this worker's devices to the callee
worker's devices. (default: ``None``)
"""
def __init__(
self,
*,
num_worker_threads: int = rpc_constants.DEFAULT_NUM_WORKER_THREADS,
rpc_timeout: float = rpc_constants.DEFAULT_RPC_TIMEOUT_SEC,
init_method: str = rpc_constants.DEFAULT_INIT_METHOD,
device_maps: Dict = None,
_transports: List = None,
_channels: List = None,
):
super().__init__(
num_worker_threads,
_transports,
_channels,
rpc_timeout,
init_method,
device_maps if device_maps else {}
)
def set_device_map(self, to: str, device_map: Dict):
r"""
Set device mapping between each RPC caller and callee pair. This
function can be called multiple times to incrementally add
device placement configurations.
Arguments:
to (str): Callee name.
device_map (Dict of int, str, or torch.device): Device placement
mappings from this worker to the callee. This map must be
invertible.
Example::
>>> # both workers
>>> def add(x, y):
>>>     print(x)  # tensor([1., 1.], device='cuda:1')
>>>     return x + y, (x + y).to(2)
>>>
>>> # on worker 0
>>> options = TensorPipeRpcBackendOptions(
>>>     num_worker_threads=8,
>>>     device_maps={"worker1": {0: 1}}
>>>     # maps worker0's cuda:0 to worker1's cuda:1
>>> )
>>> options.set_device_map("worker1", {1: 2})
>>> # maps worker0's cuda:1 to worker1's cuda:2
>>>
>>> rpc.init_rpc(
>>>     "worker0",
>>>     rank=0,
>>>     world_size=2,
>>>     backend=rpc.BackendType.TENSORPIPE,
>>>     rpc_backend_options=options
>>> )
>>>
>>> x = torch.ones(2)
>>> rets = rpc.rpc_sync("worker1", add, args=(x.to(0), 1))
>>> # The first argument will be moved to cuda:1 on worker1. When
>>> # sending the return values back, they will follow the inverse of
>>> # the device map, and hence will arrive on cuda:0 and cuda:1 on
>>> # worker0.
>>> print(rets[0])  # tensor([2., 2.], device='cuda:0')
>>> print(rets[1])  # tensor([2., 2.], device='cuda:1')
"""
device_index_map = {}
curr_device_maps = super().device_maps
for k in device_map:
v = device_map[k]
k, v = torch.device(k), torch.device(v)
if k.type != 'cuda' or v.type != 'cuda':
raise ValueError(
"`set_device_map` only supports CUDA devices, "
f"but got device pair {k}: {v}"
)
if to in curr_device_maps and k.index in curr_device_maps[to]:
curr_v = super().device_maps[to][k.index]
if curr_v != v.index:
raise ValueError(
"`set_device_map` only supports 1-to-1 mapping, "
f"trying to map {k} to {v} and {curr_v}"
)
device_index_map[k.index] = v.index
super().set_device_map(to, device_index_map)
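
A short usage note on the merge semantics (a sketch; the conflict check
raises as implemented above):

opts = TensorPipeRpcBackendOptions()
opts.set_device_map("worker1", {0: 1})
opts.set_device_map("worker1", {1: 0})  # merged into {0: 1, 1: 0}
opts.set_device_map("worker1", {0: 0})  # ValueError: 0 already maps to 1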


@ -674,7 +674,7 @@ class RpcTest(RpcAgentTestFixture):
self.init_method, rank=self.rank, world_size=self.world_size
)
)
rpc.api._init_rpc_backend(
rpc._init_rpc_backend(
backend=self.rpc_backend,
store=store,
name="duplicate_name",
@ -2514,38 +2514,6 @@ class RpcTest(RpcAgentTestFixture):
def _gpu_tensor_list_arg(tensor_list):
return torch.rand(3, 3)
@skip_if_lt_x_gpu(2)
@dist_init
def test_cuda(self):
dst = worker_name((self.rank + 1) % self.world_size)
t1 = torch.rand(3, 3).cuda(0)
t2 = torch.rand(3, 3).cuda(1)
t3 = torch.rand(3, 3)
# cuda tensors as args fail.
with self.assertRaisesRegex(RuntimeError, "RPC backend only supports CPU tensors.*Found tensor on device: cuda:0"):
rpc.rpc_sync(dst, torch.add, args=(t1, t2))
# mix of cpu and cuda tensors as args fail.
with self.assertRaisesRegex(RuntimeError, "RPC backend only supports CPU tensors.*Found tensor on device: cuda:0"):
rpc.rpc_sync(dst, torch.add, args=(t1, t3))
# gpu tensor list as args fails.
with self.assertRaisesRegex(RuntimeError, "RPC backend only supports CPU tensors.*Found tensor on device: cuda:0"):
rpc.rpc_sync(dst, RpcTest._gpu_tensor_list_arg, args=([t1, t2]))
# cuda tensors as return values fail.
with self.assertRaisesRegex(RuntimeError, "RPC backend only supports CPU tensors.*Found tensor on device: cuda:0"):
rpc.rpc_sync(dst, RpcTest._return_gpu_tensor, args=())
# cuda tensors as a list of return value fails
with self.assertRaisesRegex(RuntimeError, "RPC backend only supports CPU tensors.*Found tensor on device: cuda:0"):
rpc.rpc_sync(dst, RpcTest._return_gpu_tensor_list, args=())
# Sending to self should fail too.
with self.assertRaisesRegex(RuntimeError, "RPC backend only supports CPU tensors.*Found tensor on device: cuda:0"):
rpc.rpc_sync(worker_name(self.rank), torch.add, args=(t1, t2))
def _create_rref(self):
owner_rank = (self.rank + 2) % self.world_size
return rpc.remote(
@ -3156,6 +3124,39 @@ class RpcTest(RpcAgentTestFixture):
class ProcessGroupAgentRpcTest(RpcAgentTestFixture):
@skip_if_lt_x_gpu(2)
@dist_init
def test_cuda(self):
dst = worker_name((self.rank + 1) % self.world_size)
t1 = torch.rand(3, 3).cuda(0)
t2 = torch.rand(3, 3).cuda(1)
t3 = torch.rand(3, 3)
# cuda tensors as args fail.
with self.assertRaisesRegex(RuntimeError, "RPC backend only supports CPU tensors.*Found tensor on device: cuda:0"):
rpc.rpc_sync(dst, torch.add, args=(t1, t2))
# mix of cpu and cuda tensors as args fail.
with self.assertRaisesRegex(RuntimeError, "RPC backend only supports CPU tensors.*Found tensor on device: cuda:0"):
rpc.rpc_sync(dst, torch.add, args=(t1, t3))
# gpu tensor list as args fails.
with self.assertRaisesRegex(RuntimeError, "RPC backend only supports CPU tensors.*Found tensor on device: cuda:0"):
rpc.rpc_sync(dst, RpcTest._gpu_tensor_list_arg, args=([t1, t2]))
# cuda tensors as return values fail.
with self.assertRaisesRegex(RuntimeError, "RPC backend only supports CPU tensors.*Found tensor on device: cuda:0"):
rpc.rpc_sync(dst, RpcTest._return_gpu_tensor, args=())
# cuda tensors as a list of return value fails
with self.assertRaisesRegex(RuntimeError, "RPC backend only supports CPU tensors.*Found tensor on device: cuda:0"):
rpc.rpc_sync(dst, RpcTest._return_gpu_tensor_list, args=())
# Sending to self should fail too.
with self.assertRaisesRegex(RuntimeError, "RPC backend only supports CPU tensors.*Found tensor on device: cuda:0"):
rpc.rpc_sync(worker_name(self.rank), torch.add, args=(t1, t2))
def test_single_threaded_rref_owner(self):
# We need a process group in order to perform a barrier at the end.
dist.init_process_group(
@ -3771,3 +3772,317 @@ class TensorPipeAgentRpcTest(RpcAgentTestFixture):
num_worker_threads=self.rpc_backend_options.num_worker_threads,
rpc_timeout=timeout,
)
def _test_device_maps(self, options, errMsg="Invalid device_map"):
with self.assertRaisesRegex(ValueError, errMsg):
rpc.init_rpc(
name=worker_name(self.rank),
backend=self.rpc_backend,
rank=self.rank,
world_size=self.world_size,
rpc_backend_options=options,
)
self.assertFalse(rpc.api._is_current_rpc_agent_set())
@skip_if_lt_x_gpu(2)
def test_device_maps_wrong_worker_name(self):
options = self.rpc_backend_options
options.set_device_map("none_exist", {0: 1})
self._test_device_maps(options, "Wrong worker names")
@skip_if_lt_x_gpu(1)
def test_device_maps_invalid_max_local_device(self):
options = self.rpc_backend_options
dst = worker_name((self.rank + 1) % self.world_size)
options.set_device_map(dst, {torch.cuda.device_count(): 0})
self._test_device_maps(options)
@skip_if_lt_x_gpu(1)
def test_device_maps_invalid_max_remote_device(self):
options = self.rpc_backend_options
dst = worker_name((self.rank + 1) % self.world_size)
options.set_device_map(dst, {0: torch.cuda.device_count()})
self._test_device_maps(options)
@skip_if_lt_x_gpu(2)
def test_device_maps_many_to_one(self):
options = self.rpc_backend_options
dst = worker_name((self.rank + 1) % self.world_size)
options.set_device_map(dst, {1: 0})
options.set_device_map(dst, {0: 0})
self._test_device_maps(options)
@skip_if_lt_x_gpu(2)
def test_device_maps_one_to_many(self):
if self.rank == 0:
options = self.rpc_backend_options
dst = worker_name((self.rank + 1) % self.world_size)
options.set_device_map(dst, {0: 1})
with self.assertRaisesRegex(
ValueError, "`set_device_map` only supports 1-to-1 mapping"
):
options.set_device_map(dst, {0: 0})
@skip_if_lt_x_gpu(1)
def test_device_maps_invalid_min_device(self):
options = self.rpc_backend_options
dst = worker_name((self.rank + 1) % self.world_size)
with self.assertRaisesRegex(
RuntimeError, "Device index must not be negative"
):
options.set_device_map(dst, {-1: 0})
with self.assertRaisesRegex(
RuntimeError, "Device index must not be negative"
):
options.set_device_map(dst, {0: -1})
@staticmethod
def _gpu_add(x, y):
if all([x.is_cuda, x.device.index == 1, y.is_cuda, y.device.index == 1]):
return (x + y).to(0)
else:
raise ValueError("Wrong device affinity")
@skip_if_lt_x_gpu(2)
def test_device_maps_gpu(self):
options = self.rpc_backend_options
dst = worker_name((self.rank + 1) % self.world_size)
options.set_device_map(dst, {0: 1, 1: 0})
rpc.init_rpc(
name=worker_name(self.rank),
backend=self.rpc_backend,
rank=self.rank,
world_size=self.world_size,
rpc_backend_options=options,
)
ret = rpc.rpc_sync(
dst,
TensorPipeAgentRpcTest._gpu_add,
args=(torch.zeros(2).to(0), torch.ones(2).to(0))
)
self.assertEqual(ret.device, torch.device(1))
self.assertEqual(ret, (torch.zeros(2) + torch.ones(2)).to(1))
rpc.shutdown()
@staticmethod
def _gpu_add_multi_gpu(x, y):
if all([x.is_cuda, x.device.index == 0, y.is_cuda, y.device.index == 1]):
return x + y.to(0), x.to(1) - y
else:
raise ValueError("Wrong device affinity")
def _test_device_maps_multi_gpu(self, dst):
options = self.rpc_backend_options
options.set_device_map(dst, {1: 0})
options.set_device_map(dst, {0: 1})
rpc.init_rpc(
name=worker_name(self.rank),
backend=self.rpc_backend,
rank=self.rank,
world_size=self.world_size,
rpc_backend_options=options,
)
rets = rpc.rpc_sync(
dst,
TensorPipeAgentRpcTest._gpu_add_multi_gpu,
args=(torch.zeros(2).to(1), torch.ones(2).to(0))
)
self.assertEqual(rets[0].device, torch.device(1))
self.assertEqual(rets[1].device, torch.device(0))
self.assertEqual(rets[0], (torch.zeros(2) + torch.ones(2)).to(1))
self.assertEqual(rets[1], (torch.zeros(2) - torch.ones(2)).to(0))
rpc.shutdown()
@skip_if_lt_x_gpu(2)
def test_device_maps_multi_gpu(self):
dst = worker_name((self.rank + 1) % self.world_size)
self._test_device_maps_multi_gpu(dst)
@skip_if_lt_x_gpu(2)
def test_device_maps_multi_gpu_self(self):
dst = worker_name(self.rank)
self._test_device_maps_multi_gpu(dst)
@staticmethod
def _gpu_add_return_to_gpu(x, y):
if x.device.type == 'cpu' and y.device.type == 'cpu':
return (x + y).to(0), (x - y).to(1), (x * y).to(2), (x / y).to(3)
else:
raise ValueError("Wrong device affinity")
@skip_if_lt_x_gpu(2)
def test_device_maps_in_options(self):
dst = worker_name((self.rank + 1) % self.world_size)
options = self.rpc_backend_options
rpc.init_rpc(
name=worker_name(self.rank),
backend=self.rpc_backend,
rank=self.rank,
world_size=self.world_size,
rpc_backend_options=rpc.TensorPipeRpcBackendOptions(
init_method=options.init_method,
num_worker_threads=options.num_worker_threads,
device_maps={dst: {0: 1, 1: 0}}
)
)
rets = rpc.rpc_sync(
dst,
TensorPipeAgentRpcTest._gpu_add_multi_gpu,
args=(torch.zeros(2).to(1), torch.ones(2).to(0))
)
self.assertEqual(rets[0].device, torch.device(1))
self.assertEqual(rets[1].device, torch.device(0))
self.assertEqual(rets[0], (torch.zeros(2) + torch.ones(2)).to(1))
self.assertEqual(rets[1], (torch.zeros(2) - torch.ones(2)).to(0))
rpc.shutdown()
def _test_device_maps_return_to_gpu(self, dst):
options = self.rpc_backend_options
options.set_device_map(dst, {0: 1})
options.set_device_map(dst, {1: 2})
options.set_device_map(dst, {2: 3})
options.set_device_map(dst, {3: 0})
rpc.init_rpc(
name=worker_name(self.rank),
backend=self.rpc_backend,
rank=self.rank,
world_size=self.world_size,
rpc_backend_options=options,
)
rets = rpc.rpc_sync(
dst,
TensorPipeAgentRpcTest._gpu_add_return_to_gpu,
args=(torch.zeros(2), torch.ones(2))
)
for i in range(len(rets)):
self.assertEqual(rets[i].device, torch.device((3 + i) % 4))
self.assertEqual(rets[0], (torch.zeros(2) + torch.ones(2)).to(3))
self.assertEqual(rets[1], (torch.zeros(2) - torch.ones(2)).to(0))
self.assertEqual(rets[2], (torch.zeros(2) * torch.ones(2)).to(1))
self.assertEqual(rets[3], (torch.zeros(2) / torch.ones(2)).to(2))
rpc.shutdown()
@skip_if_lt_x_gpu(4)
def test_device_maps_return_to_gpu(self):
dst = worker_name((self.rank + 1) % self.world_size)
self._test_device_maps_return_to_gpu(dst)
@skip_if_lt_x_gpu(4)
def test_device_maps_return_to_gpu_self(self):
dst = worker_name(self.rank)
self._test_device_maps_return_to_gpu(dst)
@staticmethod
def _add_to_gpu(x, y):
return (x + y).to(0)
def _test_device_maps_missing_config(self, mode):
dst = worker_name((self.rank + 1) % self.world_size)
errMsg = (
"TensorPipeAgent only supports CPU tensors by default.*"
"`set_device_map` on `TensorPipeRpcBackendOptions`"
)
with self.assertRaisesRegex(RuntimeError, errMsg):
if mode == RPCExecMode.SYNC:
rpc.rpc_sync(dst, torch.add, args=(torch.zeros(2).to(0), 1))
elif mode == RPCExecMode.REMOTE:
rpc.remote(dst, torch.add, args=(torch.zeros(2).to(0), 1)).to_here()
else:
raise ValueError(f"unexpected mode {mode}")
# make sure RPC is still functioning
ret = rpc.rpc_sync(dst, torch.add, args=(torch.ones(2), 1))
self.assertEqual(ret, torch.ones(2) + 1)
def _test_device_maps_missing_config_response(self, mode):
dst = worker_name((self.rank + 1) % self.world_size)
errMsg = "Response device mapping is not available"
with self.assertRaisesRegex(RuntimeError, errMsg):
if mode == RPCExecMode.SYNC:
rpc.rpc_sync(
dst,
TensorPipeAgentRpcTest._add_to_gpu,
args=(torch.zeros(2), 1)
)
elif mode == RPCExecMode.REMOTE:
rpc.remote(
dst,
TensorPipeAgentRpcTest._add_to_gpu,
args=(torch.zeros(2), 1)
).to_here()
else:
raise ValueError(f"unexpected mode {mode}")
# make sure RPC is still functioning
ret = rpc.rpc_sync(dst, torch.add, args=(torch.ones(2), 1))
self.assertEqual(ret, torch.ones(2) + 1)
@skip_if_lt_x_gpu(1)
@dist_init
def test_device_maps_missing_config(self):
self._test_device_maps_missing_config(RPCExecMode.SYNC)
@skip_if_lt_x_gpu(1)
@dist_init
def test_device_maps_missing_config_loop(self):
for _ in range(self.rpc_backend_options.num_worker_threads + 5):
self._test_device_maps_missing_config(RPCExecMode.SYNC)
@skip_if_lt_x_gpu(1)
@dist_init
def test_device_maps_missing_config_response(self):
self._test_device_maps_missing_config_response(RPCExecMode.SYNC)
@skip_if_lt_x_gpu(1)
@dist_init
def test_device_maps_missing_config_response_loop(self):
for _ in range(self.rpc_backend_options.num_worker_threads + 5):
self._test_device_maps_missing_config_response(RPCExecMode.SYNC)
@skip_if_lt_x_gpu(1)
@dist_init
def test_device_maps_missing_config_remote(self):
self._test_device_maps_missing_config(RPCExecMode.REMOTE)
@skip_if_lt_x_gpu(1)
@dist_init
def test_device_maps_missing_config_remote_response(self):
self._test_device_maps_missing_config_response(RPCExecMode.REMOTE)
@skip_if_lt_x_gpu(2)
def test_device_maps_remote(self):
options = self.rpc_backend_options
dst = worker_name((self.rank + 1) % self.world_size)
options.set_device_map(dst, {1: 0})
rpc.init_rpc(
name=worker_name(self.rank),
backend=self.rpc_backend,
rank=self.rank,
world_size=self.world_size,
rpc_backend_options=options,
)
rref = rpc.remote(
dst,
TensorPipeAgentRpcTest._add_to_gpu,
args=(torch.zeros(2), 1)
)
self.assertEqual(rref.to_here(), torch.ones(2).to(1))