pytorch/torch/csrc/distributed/rpc/python_functions.h
Luca Wehrstedt 45012da298 Migrate from shared_ptr to intrusive_ptr for Future (#57636)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/57636

The "preferred" pointer holder for Future is `intrusive_ptr` (e.g., `then` returns an `intrusive_ptr`, `toFuture` returns `intrusive_ptr`, ...). However in RPC we often wrap it with `shared_ptr`. This probably dates back to when we had a separate Future type, before the merge.

At the boundary between RPC and JIT this difference becomes a bit annoying, as conversions between the pointer types are needed. I think it would be simpler and more consistent to always use `intrusive_ptr`, also in RPC.

This PR was produced mainly by find-and-replace, plus a couple of manual fixes.
ghstack-source-id: 128296581

Test Plan: CI

Reviewed By: pritamdamania87

Differential Revision: D28187972

fbshipit-source-id: d4609273a1550b4921910e85d2198e02f31c905b
2021-05-07 03:59:20 -07:00


#pragma once
#include <torch/csrc/distributed/rpc/py_rref.h>
#include <torch/csrc/distributed/rpc/rpc_agent.h>
#include <torch/csrc/jit/python/pybind_utils.h>
#include <torch/csrc/utils/pybind.h>
namespace torch {
namespace distributed {
namespace rpc {
// Converts an internal ivalue::Future of Message into a user-facing
// ivalue::Future of py::object type by creating a new ivalue::Future and
// calling its markCompleted in a callback on the given ivalue::Future.
// If hasValue is true, the Message is converted into a py::object and then
// wrapped in an IValue. If hasValue is false, the returned ivalue::Future is
// only used for signaling and launching callbacks. In this case, the Message
// is discarded and the ivalue::Future is completed with an empty IValue, or
// with the given FutureError if there is an error.
c10::intrusive_ptr<JitFuture> toPyJitFuture(
    const c10::intrusive_ptr<JitFuture>& messageJitFuture,
    bool hasValue = true);

c10::intrusive_ptr<JitFuture> pyRpcBuiltin(
    const WorkerInfo& dst,
    const std::string& opName,
    const py::args& args,
    const py::kwargs& kwargs,
    const float rpcTimeoutSeconds);

c10::intrusive_ptr<JitFuture> pyRpcPythonUdf(
    const WorkerInfo& dst,
    std::string& pickledPythonUDF,
    std::vector<torch::Tensor>& tensors,
    const float rpcTimeoutSeconds,
    const bool isAsyncExecution);

c10::intrusive_ptr<JitFuture> pyRpcTorchscript(
    const std::string& dstWorkerName,
    const std::string& qualifiedNameStr,
    const py::tuple& argsTuple,
    const py::dict& kwargsDict,
    const float rpcTimeoutSeconds,
    const bool isAsyncExecution);

PyRRef pyRemoteBuiltin(
    const WorkerInfo& dst,
    const std::string& opName,
    const float rpcTimeoutSeconds,
    const py::args& args,
    const py::kwargs& kwargs);

PyRRef pyRemotePythonUdf(
    const WorkerInfo& dst,
    std::string& pickledPythonUDF,
    std::vector<torch::Tensor>& tensors,
    const float rpcTimeoutSeconds,
    const bool isAsyncExecution);

PyRRef pyRemoteTorchscript(
    const std::string& dstWorkerName,
    const std::string& qualifiedNameStr,
    const float rpcTimeoutSeconds,
    const bool isAsyncExecution,
    const py::args& args,
    const py::kwargs& kwargs);
} // namespace rpc
} // namespace distributed
} // namespace torch
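
The comment on `toPyJitFuture` describes an adapter pattern: create a new future, then complete it from a callback registered on the internal one. The sketch below illustrates that pattern only and is not the PyTorch implementation. `SketchFuture`, `SketchMessage`, `toUserFuture` and `demoToUserFuture` are hypothetical names; it uses `std::shared_ptr` and `std::string` in place of `c10::intrusive_ptr`, `py::object` and `IValue`; and it is single-threaded, so it omits the locking a real future needs.

```cpp
#include <cassert>
#include <functional>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Hypothetical, single-threaded stand-in for ivalue::Future.
// It holds either a value or an error and runs callbacks on completion.
template <typename T>
class SketchFuture {
 public:
  using Callback = std::function<void(SketchFuture<T>&)>;

  void markCompleted(T value) {
    value_ = std::move(value);
    finish();
  }

  void setError(std::string message) {
    error_ = std::move(message);
    hasError_ = true;
    finish();
  }

  // Runs the callback immediately if the future is already complete.
  void addCallback(Callback cb) {
    if (completed_) {
      cb(*this);
    } else {
      callbacks_.push_back(std::move(cb));
    }
  }

  bool completed() const { return completed_; }
  bool hasError() const { return hasError_; }
  const std::string& error() const { return error_; }
  const T& value() const { return value_; }

 private:
  void finish() {
    completed_ = true;
    std::vector<Callback> pending;
    pending.swap(callbacks_);
    for (auto& cb : pending) {
      cb(*this);
    }
  }

  T value_{};
  std::string error_;
  bool hasError_ = false;
  bool completed_ = false;
  std::vector<Callback> callbacks_;
};

// Hypothetical stand-in for an RPC Message.
struct SketchMessage {
  std::string payload;
};

// Mirrors the documented behavior: with hasValue the message is converted,
// without it the message is discarded and an empty value only signals
// completion. Errors are forwarded in both cases.
std::shared_ptr<SketchFuture<std::string>> toUserFuture(
    const std::shared_ptr<SketchFuture<SketchMessage>>& messageFuture,
    bool hasValue = true) {
  auto userFuture = std::make_shared<SketchFuture<std::string>>();
  messageFuture->addCallback(
      [userFuture, hasValue](SketchFuture<SketchMessage>& done) {
        if (done.hasError()) {
          userFuture->setError(done.error());
        } else if (hasValue) {
          // Stands in for converting the Message into a py::object.
          userFuture->markCompleted(std::string("py:") + done.value().payload);
        } else {
          userFuture->markCompleted(std::string());
        }
      });
  return userFuture;
}

// Usage: the user-facing future completes only after the internal one does.
void demoToUserFuture() {
  auto internal = std::make_shared<SketchFuture<SketchMessage>>();
  auto user = toUserFuture(internal);
  assert(!user->completed());
  internal->markCompleted(SketchMessage{"42"});
  assert(user->completed());
  assert(user->value() == "py:42");
}
```

The callback captures the new future by value, so it stays alive until the internal future completes even if the caller drops its own handle.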