mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/48840

The CUDAFuture class needs to inspect the values it contains in order to extract their tensors (in fact, the DataPtrs backing those). These are needed first to determine which CUDA devices back those tensors, so that an event can be recorded for each such device, and later to record these DataPtrs with the CUDA caching allocator if they are used in other streams.

This became complicated when Python was added to the mix, because inspecting a Python object requires acquiring the GIL, which we couldn't do from code that was supposed to also work in C++-only mode. The solution was for users to provide a custom way to extract DataPtrs, so that the PythonFutureWrapper could install a custom, Python-aware one. This was the DataPtr extractor.

In https://github.com/pytorch/pytorch/pull/48502 a different approach was proposed. At its root, it consists of adding support for IValues of type PyObject to the visit() and getSubValues() methods. To deal with the GIL, we do this through a virtual method: PyObjectHolder, the base class, is also available in C++-only mode, and thus defines this method but leaves it unimplemented; ConcretePyObjectHolder, the subclass, is only included in Python mode, so it can implement that method, acquire the GIL, and do what it's supposed to.

In my opinion, this approach is just brilliant! Thanks to wanchaol for proposing it! It hides the complexity of dealing with Python inside getSubValues(), where it can be handled properly, thus greatly simplifying the CUDAFuture and PythonFutureWrapper classes.

ghstack-source-id: 118704935

Test Plan: Unit tests

Reviewed By: wanchaol

Differential Revision: D25334355

fbshipit-source-id: 3f1d3bf6e6e8505a114c877fb9a6fcc3f68d91d3
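The virtual-method approach described in the commit message can be sketched without Python at all. The following is a minimal illustration, not PyTorch's code: a `std::mutex` stands in for the GIL, a `std::vector<int>` stands in for the contents of a Python object, and the names `ObjectHolderBase`, `ConcreteHolder`, `Value`, `fake_gil` and `getSubValues` are hypothetical. It only shows that generic traversal code can reach the lock-acquiring logic through the base-class interface without knowing anything about the lock.

```cpp
#include <cassert>
#include <memory>
#include <mutex>
#include <utility>
#include <vector>

// Hypothetical stand-in for the GIL. The real code uses
// pybind11::gil_scoped_acquire; a std::mutex keeps this sketch Python-free.
std::mutex& fake_gil() {
  static std::mutex m;
  return m;
}

// Base class, available in "C++-only mode". It declares the hook (modelled
// here as a pure virtual function) but cannot implement it, because it knows
// nothing about the foreign object it will hold.
struct ObjectHolderBase {
  virtual ~ObjectHolderBase() = default;
  // Returns the leaf values (think: tensors) reachable from the held object.
  virtual std::vector<int> extractLeaves() = 0;
};

// Subclass, only compiled in "Python mode". It owns the foreign object and is
// the only place that knows how to lock and walk it.
struct ConcreteHolder final : ObjectHolderBase {
  explicit ConcreteHolder(std::vector<int> payload)
      : payload_(std::move(payload)) {}

  std::vector<int> extractLeaves() override {
    std::lock_guard<std::mutex> guard(fake_gil());  // "acquire the GIL"
    return payload_;  // inspect the foreign object while the lock is held
  }

 private:
  std::vector<int> payload_;
};

// Hypothetical generic value: either a plain leaf or a holder of a foreign
// object.
struct Value {
  int leaf = 0;
  std::shared_ptr<ObjectHolderBase> holder;  // non-null for foreign objects
};

// getSubValues()-style traversal. It only uses the base-class interface, so
// it builds in both modes; the locking happens behind the virtual call.
std::vector<int> getSubValues(const std::vector<Value>& values) {
  std::vector<int> out;
  for (const Value& v : values) {
    if (v.holder) {
      std::vector<int> leaves = v.holder->extractLeaves();
      out.insert(out.end(), leaves.begin(), leaves.end());
    } else {
      out.push_back(v.leaf);
    }
  }
  return out;
}
```

For example, a vector holding the plain leaf `1` followed by a `ConcreteHolder` wrapping `{2, 3}` yields `{1, 2, 3}`. A consumer in the role of CUDAFuture only calls the traversal and never touches the lock itself.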
77 lines
2.3 KiB
C++
#pragma once
#include <ATen/core/ivalue.h>
#include <pybind11/pybind11.h>
#include <torch/csrc/jit/python/pybind_utils.h>
#include <torch/csrc/python_headers.h>

namespace py = pybind11;

namespace c10 {
namespace ivalue {
|
// Concrete IValue holder that holds a py::object.
struct C10_EXPORT ConcretePyObjectHolder final : PyObjectHolder {
 public:
  static c10::intrusive_ptr<PyObjectHolder> create(py::object py_obj) {
    return c10::make_intrusive<ConcretePyObjectHolder>(std::move(py_obj));
  }
|
  static c10::intrusive_ptr<PyObjectHolder> create(const py::handle& handle) {
    // Casting a borrowed handle to an owning py::object increments the
    // reference count, so the GIL must be held.
    py::gil_scoped_acquire ag;
    return c10::make_intrusive<ConcretePyObjectHolder>(
        handle.cast<py::object>());
  }
|
  PyObject* getPyObject() override {
    return py_obj_.ptr();
  }
|
  InferredType tryToInferType() override {
    pybind11::gil_scoped_acquire ag;
    return torch::jit::tryToInferType(py_obj_);
  }
|
  IValue toIValue(const TypePtr& type, c10::optional<int32_t> N = c10::nullopt)
      override {
    pybind11::gil_scoped_acquire ag;
    return torch::jit::toIValue(py_obj_, type, N);
  }
|
  std::string toStr() override {
    pybind11::gil_scoped_acquire ag;
    return py::str(py_obj_);
  }
|
  // Note [Destructing py::object]
  // ~~~~~~~~~~~~~~~~~~~~~~~~~~
  //
  // (1) Why `py_obj_ = py::none();` does not work: destructing a py::object
  // that holds None also decrements None's reference count, so it too
  // requires the GIL.
  // https://docs.python.org/3/c-api/none.html#c.Py_RETURN_NONE
  //
  // https://stackoverflow.com/questions/15287590/why-should-py-increfpy-none-be-required-before-returning-py-none-in-c
  //
  // (2) Why we need to call dec_ref() explicitly: a py::object holding
  // nullptr effectively does nothing on destruction, because it calls
  // Py_XDECREF(NULL) under the hood. Since we null out the pointer below,
  // the member's destructor will not release the reference, so we must
  // release it here while holding the GIL.
  // https://docs.python.org/3/c-api/refcounting.html#c.Py_XDECREF
  ~ConcretePyObjectHolder() {
    pybind11::gil_scoped_acquire ag;
    py_obj_.dec_ref();
    // Explicitly set the PyObject* to nullptr to prevent py::object's
    // destructor from decref'ing the PyObject again.
    py_obj_.ptr() = nullptr;
  }
|
  // Explicit constructor to avoid erroneous implicit conversions and
  // copy-initialization.
  explicit ConcretePyObjectHolder(py::object py_obj)
      : py_obj_(std::move(py_obj)) {}
|
 private:
  py::object py_obj_;
};
|
} // namespace ivalue
} // namespace c10
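The ordering inside `~ConcretePyObjectHolder()` matters because the local GIL guard is destroyed at the end of the destructor body, before the `py_obj_` member's own destructor runs. The self-contained sketch below models that sequence with hypothetical stand-ins (`FakeObject` for a PyObject, `FakePyObjectHandle` for py::object, `ScopedFakeGil` for gil_scoped_acquire, `fake_xdecref` for Py_XDECREF); it is an illustration of the reasoning in Note [Destructing py::object], not pybind11's implementation, and it counts any decrement performed without the lock held.

```cpp
#include <cassert>
#include <mutex>

// Hypothetical stand-in for a PyObject: just a reference count.
struct FakeObject {
  int refcount = 1;
};

// Hypothetical stand-in for the GIL, plus a flag recording whether it is held.
std::mutex fake_gil;
bool fake_gil_held = false;

// Stand-in for pybind11::gil_scoped_acquire.
struct ScopedFakeGil {
  ScopedFakeGil() {
    fake_gil.lock();
    fake_gil_held = true;
  }
  ~ScopedFakeGil() {
    fake_gil_held = false;
    fake_gil.unlock();
  }
};

// Mimics Py_XDECREF: a no-op for nullptr, otherwise a decrement.
// Decrements performed without the lock are counted as bugs.
int unlocked_decrefs = 0;
void fake_xdecref(FakeObject* obj) {
  if (obj == nullptr) return;
  if (!fake_gil_held) ++unlocked_decrefs;
  --obj->refcount;
}

// Mimics py::object: owns one reference and releases it on destruction.
class FakePyObjectHandle {
 public:
  explicit FakePyObjectHandle(FakeObject* p) : ptr_(p) {}
  FakePyObjectHandle(const FakePyObjectHandle&) = delete;
  FakePyObjectHandle& operator=(const FakePyObjectHandle&) = delete;
  ~FakePyObjectHandle() { fake_xdecref(ptr_); }

  void dec_ref() { fake_xdecref(ptr_); }
  FakeObject*& ptr() { return ptr_; }

 private:
  FakeObject* ptr_;
};

// Mimics ConcretePyObjectHolder's destructor. The guard is destroyed at the
// end of the destructor body, before obj_'s destructor runs, so obj_ must
// already hold nullptr by then.
class FakeHolder {
 public:
  explicit FakeHolder(FakeObject* p) : obj_(p) {}
  ~FakeHolder() {
    ScopedFakeGil guard;
    obj_.dec_ref();        // release the reference while the lock is held
    obj_.ptr() = nullptr;  // turn obj_'s destructor into a no-op
  }

 private:
  FakePyObjectHandle obj_;
};
```

Destroying a `FakeHolder` drops the count to zero with no unlocked decrement, whereas destroying a bare `FakePyObjectHandle` outside any guard performs its decrement unlocked, which is the situation the note guards against.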