Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/38748

This diff contains the message scaffolding and profiler changes needed to remotely run the profiler across different nodes and aggregate the results on a single node.

As discussed, we have implemented this by creating new message types that, similar to autograd messages, wrap the profiling information with the original message and send this new message over the wire. On the receiving end, this wrapped message is detected, the original message is fetched from it, and the original message is processed with the profiler enabled. When sending a response with profiling information, we serialize the profiled `Event`s and send them back over RPC. When such a message is received, the events profiled on the remote node are stored (added back to the local profiler).

Changes in this PR:
- New message types (run_with_profiling_req, run_with_profiling_resp) to send profiling info over the wire. Message parsing logic is added to handle these wrapped types.
- Handling of sending profiler data over the wire, in particular the attributes of the `ProfilerConfig` and the serialized profiled `Event`s
- The logic for wrapping RPC messages is deduped with that in `rpc_with_autograd`, and the common payload wrapping/unwrapping logic is moved to helper functions in `rpc/utils.cpp`
- Changes in `autograd/utils.cpp` to detect whether we have enabled the profiler and are sending an RPC; if so, the above new message types are used
- Changes in request_callback to parse and turn on the profiler in a thread-local fashion
- Serialization and deserialization of profiling `Event`s, and support for adding the remote events to the thread-local profiler
- Introduction of the concept of `node_id`, which, as discussed with ilia-cher, will be used along with the `Event`'s handle attribute to distinguish between events. When there are events from different nodes, this node information is rendered in the profile output (e.g. when printing tables); otherwise it is not, since it is irrelevant.
- Some changes to profiler.cpp to add useful helper methods/guards
- toHere() is now profiled for RRefs
- Unit tests

ghstack-source-id: 106134626

Test Plan: Added unit tests; existing profiler unit tests.

Differential Revision: D19510010

fbshipit-source-id: 044347af992f19a9e3b357c9567f6fc73e988157
#include <torch/csrc/distributed/rpc/utils.h>

#include <torch/csrc/autograd/profiler.h>
#include <torch/csrc/distributed/autograd/rpc_messages/cleanup_autograd_context_req.h>
#include <torch/csrc/distributed/autograd/rpc_messages/cleanup_autograd_context_resp.h>
#include <torch/csrc/distributed/autograd/rpc_messages/propagate_gradients_req.h>
#include <torch/csrc/distributed/autograd/rpc_messages/propagate_gradients_resp.h>
#include <torch/csrc/distributed/autograd/rpc_messages/rpc_with_autograd.h>
#include <torch/csrc/distributed/autograd/rpc_messages/rpc_with_profiling_req.h>
#include <torch/csrc/distributed/autograd/rpc_messages/rpc_with_profiling_resp.h>
#include <torch/csrc/distributed/autograd/utils.h>
#include <torch/csrc/distributed/rpc/python_call.h>
#include <torch/csrc/distributed/rpc/python_remote_call.h>
#include <torch/csrc/distributed/rpc/python_resp.h>
#include <torch/csrc/distributed/rpc/rref_proto.h>
#include <torch/csrc/distributed/rpc/script_call.h>
#include <torch/csrc/distributed/rpc/script_remote_call.h>
#include <torch/csrc/distributed/rpc/script_resp.h>
#include <torch/csrc/jit/serialization/pickler.h>
#include <torch/csrc/jit/serialization/unpickler.h>

#include <fmt/format.h>
namespace torch {
namespace distributed {
namespace rpc {

const std::string kRPCErrorPrefix = std::string("RPCErr");

RPCErrorType getRPCErrorType(const FutureMessage& fm) {
  TORCH_INTERNAL_ASSERT(
      fm.hasError(),
      "FutureMessage passed to getRPCErrorType does not have an error.");

  // Attempt to parse out the error string set by makeRPCError(); errors that
  // were not formatted with makeRPCError() map to UNKNOWN_ERROR.
  auto err = std::string(fm.error()->what());
  size_t pos = err.find(kRPCErrorPrefix);
  if (pos != std::string::npos) {
    // Parse the RPCErrorType.
    auto errStartIdx =
        pos + torch::distributed::rpc::kRPCErrorPrefix.size() + 1;
    auto errEndIdx = err.find(':', errStartIdx);
    if (errEndIdx == std::string::npos) {
      // Indicates the error was not formatted correctly.
      return RPCErrorType::UNKNOWN_ERROR;
    }
    auto errStr = err.substr(errStartIdx, errEndIdx - errStartIdx);
    auto errType = static_cast<RPCErrorType>(std::stoi(errStr));
    return errType;
  } else {
    return RPCErrorType::UNKNOWN_ERROR;
  }
}

std::string makeRPCError(
    const std::string& rpcErrorStr,
    RPCErrorType errorType) {
  return fmt::format(
      "{}:{}:{}",
      torch::distributed::rpc::kRPCErrorPrefix,
      errorType,
      rpcErrorStr);
}
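
// Illustrative sketch (not part of the original source): makeRPCError() and
// getRPCErrorType() agree on the format "RPCErr:<errorType>:<message>".
// Assuming the RPCErrorType::UNKNOWN_ERROR enumerator has underlying value
// 0, then:
//
//   std::string err = makeRPCError("oops", RPCErrorType::UNKNOWN_ERROR);
//   // err == "RPCErr:0:oops"
//
// A FutureMessage whose error text contains `err` yields that same
// RPCErrorType from getRPCErrorType(); error strings lacking the "RPCErr"
// prefix, or the second ':', fall back to UNKNOWN_ERROR.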

std::unique_ptr<RpcCommandBase> deserializeRequest(const Message& request) {
  switch (request.type()) {
    case MessageType::SCRIPT_CALL: {
      return ScriptCall::fromMessage(request);
    }
    case MessageType::PYTHON_CALL: {
      return PythonCall::fromMessage(request);
    }
    case MessageType::SCRIPT_REMOTE_CALL: {
      return ScriptRemoteCall::fromMessage(request);
    }
    case MessageType::PYTHON_REMOTE_CALL: {
      return PythonRemoteCall::fromMessage(request);
    }
    case MessageType::SCRIPT_RREF_FETCH_CALL: {
      return ScriptRRefFetchCall::fromMessage(request);
    }
    case MessageType::PYTHON_RREF_FETCH_CALL: {
      return PythonRRefFetchCall::fromMessage(request);
    }
    case MessageType::RREF_USER_DELETE: {
      return RRefUserDelete::fromMessage(request);
    }
    case MessageType::RREF_CHILD_ACCEPT: {
      return RRefChildAccept::fromMessage(request);
    }
    case MessageType::RREF_FORK_REQUEST: {
      return RRefForkRequest::fromMessage(request);
    }
    case MessageType::FORWARD_AUTOGRAD_REQ: {
      return autograd::RpcWithAutograd::fromMessage(request);
    }
    case MessageType::BACKWARD_AUTOGRAD_REQ: {
      return autograd::PropagateGradientsReq::fromMessage(request);
    }
    case MessageType::CLEANUP_AUTOGRAD_CONTEXT_REQ: {
      return autograd::CleanupAutogradContextReq::fromMessage(request);
    }
    case MessageType::RUN_WITH_PROFILING_REQ: {
      return autograd::RpcWithProfilingReq::fromMessage(request);
    }
    default: {
      TORCH_INTERNAL_ASSERT(
          false, "Request type ", request.type(), " not supported.");
    }
  }
}
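
// Dispatch sketch (illustrative only): the receiving agent calls
// deserializeRequest() on every inbound Message. Note that for a
// profiler-wrapped request this returns the wrapper itself; per this PR,
// request_callback unwraps it and enables the profiler before running the
// inner RPC.
//
//   std::unique_ptr<RpcCommandBase> rpc = deserializeRequest(message);
//   // message.type() == MessageType::RUN_WITH_PROFILING_REQ
//   //   => rpc is an autograd::RpcWithProfilingReq wrapping, e.g., a
//   //      SCRIPT_CALL, which the request callback unwraps.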

std::unique_ptr<RpcCommandBase> deserializeResponse(
    const Message& response,
    MessageType& wrappedMsgType) {
  switch (response.type()) {
    case MessageType::SCRIPT_RET: {
      return ScriptResp::fromMessage(response);
    }
    case MessageType::PYTHON_RET: {
      return PythonResp::fromMessage(response);
    }
    case MessageType::REMOTE_RET: {
      return RemoteRet::fromMessage(response);
    }
    case MessageType::SCRIPT_RREF_FETCH_RET: {
      return ScriptRRefFetchRet::fromMessage(response);
    }
    case MessageType::PYTHON_RREF_FETCH_RET: {
      return PythonRRefFetchRet::fromMessage(response);
    }
    case MessageType::RREF_ACK: {
      return RRefAck::fromMessage(response);
    }
    case MessageType::FORWARD_AUTOGRAD_RESP: {
      std::unique_ptr<RpcCommandBase> rpcPtr =
          autograd::RpcWithAutograd::fromMessage(response);
      RpcCommandBase& rpc = *rpcPtr;
      auto& rpcWithAutograd = static_cast<autograd::RpcWithAutograd&>(rpc);

      // Attach 'recv' autograd function.
      addRecvRpcBackward(
          rpcWithAutograd.autogradMetadata(),
          rpcWithAutograd.tensors(),
          rpcWithAutograd.fromWorkerId());

      wrappedMsgType = rpcWithAutograd.wrappedMessageType();

      return std::move(rpcWithAutograd).moveWrappedRpc();
    }
    case MessageType::BACKWARD_AUTOGRAD_RESP: {
      return autograd::PropagateGradientsResp::fromMessage(response);
    }
    case MessageType::CLEANUP_AUTOGRAD_CONTEXT_RESP: {
      return autograd::CleanupAutogradContextResp::fromMessage(response);
    }
    case MessageType::RUN_WITH_PROFILING_RESP: {
      std::unique_ptr<RpcCommandBase> rpcPtr =
          autograd::RpcWithProfilingResp::fromMessage(response);
      RpcCommandBase& rpc = *rpcPtr;
      auto& rpcWithProfilingResp =
          static_cast<autograd::RpcWithProfilingResp&>(rpc);
      // Read and reconstruct the remotely profiled events. The profiler is
      // expected to already be enabled on this thread.
      auto enabled = torch::autograd::profiler::profilerEnabled();
      TORCH_CHECK(
          enabled,
          "Profiler was expected to be enabled. This can happen in callback "
          "continuations that run in different threads, and the TLS of the "
          "profiler was not propagated.");
      std::vector<torch::autograd::profiler::Event> events =
          rpcWithProfilingResp.getProfiledEvents();
      wrappedMsgType = rpcWithProfilingResp.wrappedMessageType();

      torch::autograd::profiler::addEventList(std::move(events));

      auto wrappedRPC = std::move(rpcWithProfilingResp).moveWrappedRpc();
      return wrappedRPC;
    }
    default: {
      TORCH_INTERNAL_ASSERT(
          false, "Response type ", response.type(), " not supported.");
    }
  }
}
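
// Usage sketch (illustrative only): wrappedMsgType is an out-parameter. For
// plain responses it is left untouched; for wrapped responses it is
// overwritten with the inner message type, so callers typically seed it with
// the outer type:
//
//   MessageType msgType = response.type();
//   std::unique_ptr<RpcCommandBase> rpc = deserializeResponse(response, msgType);
//   // For a RUN_WITH_PROFILING_RESP wrapping a SCRIPT_RET, msgType now
//   // holds MessageType::SCRIPT_RET and rpc is the unwrapped ScriptResp.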

IValue deserializeResptoIValueInternal(
    RpcCommandBase& rpc,
    MessageType messageType) {
  switch (messageType) {
    case MessageType::SCRIPT_RET: {
      auto& ret = static_cast<ScriptResp&>(rpc);
      return ret.value();
    }
    default: {
      TORCH_INTERNAL_ASSERT(
          false,
          "Response type ",
          messageType,
          " is not supported to be deserialized to IValue.");
    }
  }
}

IValue deserializeRespToIValue(const Message& message) {
  MessageType msgType = message.type();
  auto response = deserializeResponse(message, msgType);
  return deserializeResptoIValueInternal(*response, msgType);
}
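
// For example (illustrative), the IValue result of a remote TorchScript call
// whose response is a SCRIPT_RET message can be fetched with:
//
//   IValue result = deserializeRespToIValue(scriptRetMessage);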

namespace {

// Helper for wireDeserialize() below.
//
// The format we use below looks like:
//    section_name_1 size_1\n
//    section_name_2 size_2\n
//    ..
//    \n
//    [sections in order]
//
// Sections themselves include:
//    - "payload" - the payload bits
//    - "meta"    - metadata for the unpickler
//    - "0" ...   - tensor sections for the unpickler
//
// Note that, per the header comments, the format is subject to change and is
// best used for RPCs rather than persistent disk storage.
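//
// For example (illustrative, not from the original header), a message with
// an 11-byte payload, a 20-byte pickle "meta" section, and one 16-byte
// tensor section named "0" would be laid out as:
//
//    payload 11\n
//    meta 20\n
//    0 16\n
//    \n
//    <11 payload bytes><20 meta bytes><16 tensor bytes>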
std::unordered_map<std::string, std::pair<const char*, size_t>>
parseWireSections(const void* data, size_t data_size) {
  const char* ptr = static_cast<const char*>(data);
  const char* endp = ptr + data_size;

  std::vector<std::pair<std::string, size_t>> headerEnts;
  bool ok = false;
  while (ptr != endp) {
    if (*ptr == '\n') {
      ok = true; // The only "correct" exit point.
      ++ptr;
      break;
    }
    // Parse the section name.
    const char* namePtr = ptr;
    while (ptr != endp && *ptr != ' ') {
      ptr++;
    }
    if (ptr == endp) {
      break;
    }
    std::string name(namePtr, ptr - namePtr);
    if (++ptr == endp) {
      break; // past the ' '
    }
    // Parse the section size.
    const char* sizePtr = ptr;
    while (ptr != endp && *ptr != '\n') {
      ptr++;
    }
    if (ptr == endp) {
      break;
    }
    size_t sz = c10::stoll(std::string(sizePtr, ptr - sizePtr));
    headerEnts.emplace_back(name, sz);
    ++ptr; // past the '\n'
  }
  if (!ok) {
    throw std::runtime_error("failed parse");
  }

  std::unordered_map<std::string, std::pair<const char*, size_t>> out;
  for (const auto& headerEnt : headerEnts) {
    out[headerEnt.first] = {ptr, headerEnt.second};
    ptr += headerEnt.second;
  }
  if (ptr != endp) {
    throw std::runtime_error("failed bounds");
  }
  return out;
}

static const char* kMeta = "meta";
static const char* kPayload = "payload";
} // namespace

c10::List<at::Tensor> cloneSparseTensors(
    const std::vector<at::Tensor>& tensors) {
  // Sanity-check: if the majority of bits don't need to go over the wire,
  // force a clone(). Some Tensors are effectively small views, only using
  // ~1% of the underlying Storage.
  auto worthRecopying = [](const at::Tensor& t) -> bool {
    if (!t.has_storage()) {
      return false; // avoid throwing below.
    }
    auto storageSize = t.storage().nbytes();
    auto usefulSize = t.element_size() * t.numel();
    constexpr size_t kMinMultiple = 2;
    constexpr size_t kMinRecopyBytes = 8 * 1024;
    return storageSize >= kMinRecopyBytes &&
        storageSize >= usefulSize * kMinMultiple;
  };
  c10::List<at::Tensor> pTensors;
  pTensors.reserve(tensors.size());
  for (const auto& t : tensors) {
    pTensors.push_back(worthRecopying(t) ? t.clone() : t);
  }
  return pTensors;
}
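
// Worked example of the heuristic above (illustrative): a 100-element float
// slice viewed into a 1M-element float Storage has usefulSize = 4 * 100 =
// 400 bytes, but storageSize = 4 * 1,000,000 = 4,000,000 bytes. Since
// 4,000,000 >= 8 * 1024 and 4,000,000 >= 2 * 400, the tensor is cloned
// before being sent; a tensor that uses most of its Storage is sent as-is.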

std::string wireSerialize(
    const std::vector<char>& payload,
    const std::vector<at::Tensor>& tensors) {
  for (const auto& tensor : tensors) {
    TORCH_CHECK(
        tensor.device().is_cpu(),
        "ProcessGroup RPC backend only supports",
        " CPU tensors, please move your tensors to CPU before sending ",
        "them over RPC. Found tensor on device: ",
        tensor.device());
  }

  struct Ent {
    std::string name;
    const char* data;
    size_t size;
  };
  std::vector<Ent> entries;
  std::string metaEntry;
  std::vector<at::Tensor> tensorData;

  if (!payload.empty()) {
    entries.push_back({kPayload, payload.data(), payload.size()});
  }

  if (!tensors.empty()) {
    torch::jit::Pickler pickler([&](const void* buf, size_t sz) -> size_t {
      metaEntry.append(static_cast<const char*>(buf), sz);
      return sz;
    });
    pickler.protocol();
    pickler.pushIValue(cloneSparseTensors(tensors));
    pickler.stop();
    tensorData = pickler.tensorData();
    entries.push_back({kMeta, metaEntry.data(), metaEntry.size()});
    for (size_t i = 0; i < tensorData.size(); i++) {
      // Construct WritableTensorData for each tensor in the pickler's
      // tensorData. Since tensorData lives in function scope and
      // getWriteableTensorData() only records the tensors, the data()
      // pointers stay valid for CPU tensors. Note that RPC serde doesn't
      // support CUDA tensors yet; if we add CUDA support we need to be
      // careful, since getWriteableTensorData() converts a CUDA tensor to
      // CPU, and its data() could be destructed once this loop iteration
      // goes out of scope.
      auto writeableTensorData = jit::getWriteableTensorData(tensorData[i]);
      entries.push_back({c10::to_string(i),
                         writeableTensorData.data(),
                         writeableTensorData.sizeInBytes()});
    }
  }

  std::string header;
  size_t tot = 0;
  for (const auto& e : entries) {
    tot += e.size;
    header.append(e.name)
        .append(" ")
        .append(c10::to_string(e.size))
        .append("\n");
  }
  header.push_back('\n');

  std::string out;
  out.reserve(header.size() + tot);
  out.append(header);
  for (const auto& e : entries) {
    out.append(e.data, e.size);
  }
  return out;
}

std::pair<std::vector<char>, std::vector<at::Tensor>> wireDeserialize(
    const void* data,
    size_t data_size) {
  auto sections = parseWireSections(data, data_size);

  std::vector<char> payload;
  auto payloadIt = sections.find(kPayload);
  if (payloadIt != sections.end() && payloadIt->second.second != 0) {
    payload.assign(
        payloadIt->second.first,
        payloadIt->second.first + payloadIt->second.second);
  }

  std::vector<at::Tensor> tensors;
  auto metaIt = sections.find(kMeta);
  if (metaIt != sections.end()) {
    const auto& metaData = metaIt->second;
    size_t metaDataPos = 0;
    auto metaDataReadFunc = [&](char* buf, size_t n) -> size_t {
      if (metaDataPos >= metaData.second || n == 0) {
        return 0;
      }
      size_t toCopy = std::min(metaDataPos + n, metaData.second) - metaDataPos;
      memcpy(buf, metaData.first + metaDataPos, toCopy);
      metaDataPos += toCopy;
      return toCopy;
    };
    auto sectionReadFunc = [&](const std::string& ename) -> at::DataPtr {
      auto it = sections.find(ename);
      if (it == sections.end()) {
        throw std::runtime_error("Couldn't find entity " + ename);
      }
      const auto& idat = it->second;
      auto dptr = at::getCPUAllocator()->allocate(idat.second);
      if (idat.second != 0) {
        memcpy(dptr.get(), idat.first, idat.second);
      }
      return dptr;
    };

    // No need to pass a typeResolver here, since it only ever processes
    // strings and tensors.
    torch::jit::Unpickler unpickler(
        metaDataReadFunc, nullptr, nullptr, sectionReadFunc, {});
    auto ival = unpickler.parse_ivalue();
    for (auto&& t : ival.toTensorList()) {
      tensors.emplace_back(std::move(t));
    }
  }
  return {std::move(payload), std::move(tensors)};
}
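
// Round-trip sketch (illustrative only; tensors must be on CPU, per the
// check in wireSerialize()):
//
//   std::vector<char> payload = {'a', 'b', 'c'};
//   std::vector<at::Tensor> tensors = {torch::ones({2, 2})};
//   std::string wire = wireSerialize(payload, tensors);
//   auto parsed = wireDeserialize(wire.data(), wire.size());
//   // parsed.first == payload; parsed.second holds one 2x2 tensor of ones.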

namespace {

// The TensorPipe agent splits the RPC message's information across multiple
// payloads. This allows the agent to provide the data to TensorPipe without
// performing a copy into a single contiguous buffer, and without storing it
// as metadata, which is less efficient.

// First come the rpc::Message::type() and ::id().
constexpr int kTpMessageTypeIdx = 0;
constexpr int kTpMessageIdIdx = 1;
// Then comes the rpc::Message::payload().
constexpr int kTpMessagePayloadIdx = 2;
// Last comes the pickle of rpc::Message::tensors() (with the tensors
// themselves stored as, well, tensors in the tensorpipe::Message).
constexpr int kTpMessagePickleIdx = 3;

} // namespace

std::tuple<tensorpipe::Message, TensorpipeWriteBuffers> tensorpipeSerialize(
    Message&& rpcMessage) {
  tensorpipe::Message tpMessage;
  TensorpipeWriteBuffers buffers;

  // Metadata
  buffers.type = std::make_unique<MessageType>(rpcMessage.type());
  buffers.id = std::make_unique<int64_t>(rpcMessage.id());
  tpMessage.payloads.push_back(
      tensorpipe::Message::Payload{buffers.type.get(), sizeof(MessageType)});
  tpMessage.payloads.push_back(
      tensorpipe::Message::Payload{buffers.id.get(), sizeof(int64_t)});

  // Payload
  buffers.payload = std::move(rpcMessage.payload());
  // TensorPipe uses the same Message class for both reading and writing, thus
  // it uses non-const pointers even though it doesn't modify them when
  // writing.
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
  char* payloadPtr = const_cast<char*>(buffers.payload.data());
  tpMessage.payloads.push_back(
      tensorpipe::Message::Payload{payloadPtr, buffers.payload.size()});

  // Tensors
  buffers.tensors = cloneSparseTensors(rpcMessage.tensors()).vec();
  torch::jit::Pickler pickler([&](const void* buf, size_t sz) -> size_t {
    buffers.pickle.insert(
        buffers.pickle.end(),
        static_cast<const char*>(buf),
        static_cast<const char*>(buf) + sz);
    return sz;
  });
  pickler.protocol();
  pickler.pushIValue(buffers.tensors);
  pickler.stop();
  tpMessage.payloads.push_back(tensorpipe::Message::Payload{
      buffers.pickle.data(), buffers.pickle.size()});
  for (const auto& tensor : pickler.tensorData()) {
    const auto& tensorData = jit::getWriteableTensorData(tensor);
    // Enforce a memory copy if the tensor was created from torch::from_blob,
    // meaning that the tensor doesn't own the memory.
    if (!tensorData.storageHasDeleter()) {
      std::vector<char> storageData(
          tensorData.data(), tensorData.data() + tensorData.sizeInBytes());
      tpMessage.tensors.push_back(
          tensorpipe::Message::Tensor{storageData.data(), storageData.size()});
      buffers.copiedTensors.push_back(std::move(storageData));
    } else {
      // TensorPipe uses the same Message class for both reading and writing,
      // so it uses non-const ptrs even though it doesn't modify them when
      // writing.
      // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
      char* tensorPtr = const_cast<char*>(tensorData.data());
      tpMessage.tensors.push_back(
          tensorpipe::Message::Tensor{tensorPtr, tensorData.sizeInBytes()});
    }
  }

  return std::make_tuple(std::move(tpMessage), std::move(buffers));
}
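
// Lifetime sketch (illustrative only): the returned tensorpipe::Message holds
// raw pointers into the returned TensorpipeWriteBuffers, so the buffers must
// be kept alive until TensorPipe finishes writing:
//
//   tensorpipe::Message tpMessage;
//   TensorpipeWriteBuffers buffers;
//   std::tie(tpMessage, buffers) = tensorpipeSerialize(std::move(rpcMessage));
//   // ... hand tpMessage to the pipe; keep `buffers` alive (e.g. in the
//   // write callback's closure) until the write completes.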

TensorpipeReadBuffers tensorpipeAllocate(tensorpipe::Message& tpMessage) {
  TensorpipeReadBuffers buffers;

  TORCH_INTERNAL_ASSERT(
      tpMessage.payloads.size() == 4,
      "message expected to contain 4 payloads, whereas it contained ",
      tpMessage.payloads.size(),
      " payloads");

  TORCH_INTERNAL_ASSERT(
      tpMessage.payloads[kTpMessageTypeIdx].length == sizeof(MessageType),
      "first payload expected to contain ",
      sizeof(MessageType),
      " bytes, whereas it contained ",
      tpMessage.payloads[kTpMessageTypeIdx].length,
      " bytes");
  buffers.type = std::make_unique<MessageType>();
  tpMessage.payloads[kTpMessageTypeIdx].data = buffers.type.get();

  TORCH_INTERNAL_ASSERT(
      tpMessage.payloads[kTpMessageIdIdx].length == sizeof(int64_t),
      "second payload expected to contain ",
      sizeof(int64_t),
      " bytes, whereas it contained ",
      tpMessage.payloads[kTpMessageIdIdx].length,
      " bytes");
  buffers.id = std::make_unique<int64_t>();
  tpMessage.payloads[kTpMessageIdIdx].data = buffers.id.get();

  // FIXME The two resizes below zero out the vectors, which is not needed.

  buffers.payload.resize(tpMessage.payloads[kTpMessagePayloadIdx].length);
  tpMessage.payloads[kTpMessagePayloadIdx].data = buffers.payload.data();

  buffers.pickle.resize(tpMessage.payloads[kTpMessagePickleIdx].length);
  tpMessage.payloads[kTpMessagePickleIdx].data = buffers.pickle.data();

  for (auto& tensor : tpMessage.tensors) {
    buffers.tensors.push_back(at::getCPUAllocator()->allocate(tensor.length));
    tensor.data = buffers.tensors.back().get();
  }

  return buffers;
}

Message tensorpipeDeserialize(
    tensorpipe::Message&& message,
    TensorpipeReadBuffers&& buffers) {
  // Tensors
  std::vector<at::Tensor> tensors;
  const char* pickleData = buffers.pickle.data();
  size_t pickleLen = buffers.pickle.size();
  size_t picklePos = 0;
  auto pickleReadFunc = [&](char* buf, size_t n) -> size_t {
    if (picklePos >= pickleLen || n == 0) {
      return 0;
    }
    size_t toCopy = std::min(picklePos + n, pickleLen) - picklePos;
    memcpy(buf, pickleData + picklePos, toCopy);
    picklePos += toCopy;
    return toCopy;
  };
  auto tensorReadFunc = [&](const std::string& ename) -> at::DataPtr {
    unsigned long index = std::stoul(ename);
    return std::move(buffers.tensors.at(index));
  };

  // No need to pass a typeResolver here, since it only ever processes
  // strings and tensors.
  torch::jit::Unpickler unpickler(
      pickleReadFunc, nullptr, nullptr, tensorReadFunc, {});
  auto ival = unpickler.parse_ivalue();
  for (auto&& t : ival.toTensorList()) {
    tensors.emplace_back(std::move(t));
  }

  return Message(
      std::move(buffers.payload),
      std::move(tensors),
      *buffers.type,
      *buffers.id);
}
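
// Read-side sketch (illustrative only): the agent first reads the message
// descriptor, sizes the buffers with tensorpipeAllocate(), lets TensorPipe
// fill them, and then rebuilds the rpc::Message:
//
//   TensorpipeReadBuffers buffers = tensorpipeAllocate(tpMessage);
//   // ... TensorPipe reads the payloads and tensors into `buffers` ...
//   Message rpcMessage =
//       tensorpipeDeserialize(std::move(tpMessage), std::move(buffers));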

void writeWrappedPayload(
    std::vector<char>& originalPayload,
    std::vector<char>& additionalPayload) {
  originalPayload.insert(
      originalPayload.end(),
      additionalPayload.begin(),
      additionalPayload.end());

  // Append the size of the additional payload as an 8-byte big-endian
  // trailer.
  int64_t indexToWrite = originalPayload.size();
  originalPayload.resize(originalPayload.size() + sizeof(int64_t));
  const int64_t additionalPayloadSize = additionalPayload.size();
  torch::utils::THP_encodeInt64Buffer(
      reinterpret_cast<uint8_t*>(originalPayload.data()) + indexToWrite,
      &additionalPayloadSize,
      torch::utils::THPByteOrder::THP_BIG_ENDIAN,
      1);
}
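
// Resulting layout (illustrative): given a 10-byte original payload and a
// 6-byte additional payload, the vector afterwards holds 24 bytes:
//
//   [10 original bytes][6 additional bytes][8-byte big-endian size = 6]
//
// readWrappedPayload() below reverses this: it decodes the trailing size,
// unpickles the additional payload region, and shrinks the vector back to
// the original 10 bytes.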

std::vector<at::IValue> readWrappedPayload(
    std::vector<char>& payload,
    const rpc::Message& message) {
  // Read the size of the additional payload from the trailer and remove the
  // trailer from the payload. Check the size up front: since indexToRead is
  // unsigned, the subtraction below would otherwise wrap around on an
  // undersized payload.
  TORCH_INTERNAL_ASSERT(
      payload.size() >= sizeof(int64_t),
      "Payload is too small to contain the wrapped payload size.");
  int64_t additionalPayloadSize;
  size_t indexToRead = payload.size() - sizeof(int64_t);
  torch::utils::THP_decodeInt64Buffer(
      &additionalPayloadSize,
      reinterpret_cast<uint8_t*>(payload.data()) + indexToRead,
      torch::utils::THPByteOrder::THP_BIG_ENDIAN,
      1);
  payload.resize(indexToRead);

  TORCH_INTERNAL_ASSERT(
      payload.size() > additionalPayloadSize,
      "Wrong payload sizes: payload.size() is ",
      payload.size(),
      " but additional payload size is ",
      additionalPayloadSize);
  auto wrappedPayloadBegin =
      static_cast<const char*>(message.payload().data()) + payload.size() -
      additionalPayloadSize;
  std::vector<torch::Tensor> tensorTable;
  IValue tuple = jit::unpickle(
      wrappedPayloadBegin,
      additionalPayloadSize,
      *rpc::RpcAgent::getCurrentRpcAgent()->getTypeResolver(),
      &tensorTable);
  std::vector<at::IValue> tupleElements = tuple.toTuple()->elements();
  payload.resize(payload.size() - additionalPayloadSize);
  return tupleElements;
}
} // namespace rpc
} // namespace distributed
} // namespace torch