Use at most one shared_ptr block at a time to manage THPFunctions (#1454)

* Fix failing ln in build_all.sh

* Use at most one shared_ptr block at a time to manage THPFunctions
This commit is contained in:
Adam Paszke 2017-05-03 14:15:36 +02:00 committed by Soumith Chintala
parent e1278d4ee2
commit 72e8190994
4 changed files with 28 additions and 4 deletions

View File

@ -270,6 +270,7 @@ static void THPFunction_dealloc(THPFunction* self)
{
PyObject_GC_UnTrack(self);
THPFunction_clear(self);
self->cdata_ptr.~weak_ptr();
self->cdata.~PyFunction();
Py_TYPE(self)->tp_free((PyObject*)self);
}
@ -282,6 +283,7 @@ PyObject *THPFunction_new(PyTypeObject *type, PyObject *args, PyObject *kwargs)
// most fields
THPFunction* self = (THPFunction*)obj;
new (&self->cdata) torch::autograd::PyFunction(obj);
new (&self->cdata_ptr) std::weak_ptr<torch::autograd::PyFunction>();
self->cdata.num_inputs = -1;
self->cdata.is_stochastic = PyObject_IsInstance(obj, THPStochasticFunctionClass);
return obj;
@ -998,11 +1000,29 @@ struct Decref {
}
};
// Similar to shared_from_this. There's a problem that the Python object
// and its cdata depend on each other being alive, so we can't keep
// shared_ptrs as members, but we'd like to be able to manage the lifetime of
// the objects using shared_ptrs in the C++ graph. The only way to get a new
// shared_ptr that references them is through THPFunction_asFunction. When
// called for the first time it will allocate a new shared_ptr and save a
// weak_ptr in cdata_ptr attr. Later, when we try to take another reference,
// we'll try to lock cdata_ptr and return its value if successful. Otherwise it
// means that all shared_ptrs returned previously have been freed, so we can
// create a new one. This ensures that this object is managed by at most one
// shared_ptr control block at any time - a guarantee we depend on in other places
// (e.g. we use weak_ptrs in SavedVariable because we know it won't go out of scope).
std::shared_ptr<PyFunction> THPFunction_asFunction(THPFunction* self)
{
if (!self) {
return std::shared_ptr<PyFunction>();
}
Py_INCREF((PyObject*)self);
return std::shared_ptr<PyFunction>(&self->cdata, Decref());
auto ptr = self->cdata_ptr.lock();
if (ptr) return ptr;
ptr = std::shared_ptr<PyFunction>(&self->cdata, Decref());
self->cdata_ptr = ptr;
return ptr;
}

View File

@ -43,6 +43,8 @@ struct THPFunction {
std::vector<bool> *is_variable_input;
char has_freed_buffers;
// See a comment in THPFucntion_asFunction for details about this field.
std::weak_ptr<torch::autograd::PyFunction> cdata_ptr;
torch::autograd::PyFunction cdata;
};

View File

@ -93,7 +93,7 @@ auto SavedVariable::unpack() -> std::shared_ptr<Variable> {
// should have saved the grad accumulator. Even if the Variable no longer
// alive, the accumulator should be kept alive by the references in the graph).
if (requires_grad && !grad_fn && weak_grad_fn.expired() && grad_accumulator.expired())
throw std::logic_error("No grad accumulator for a saved leaf!");
throw std::logic_error("No grad accumulator for a saved leaf!");
new_var->grad_accumulator = grad_accumulator;
return new_var;

View File

@ -28,7 +28,7 @@ BASIC_C_FLAGS=" -DTH_INDEX_BASE=0 -I$INSTALL_DIR/include \
LDFLAGS="-L$INSTALL_DIR/lib "
LD_POSTFIX=".so.1"
LD_POSTFIX_UNVERSIONED=".so"
if [[ $(uname) == 'Darwin' ]]; then
if [[ $(uname) == 'Darwin' ]]; then
LDFLAGS="$LDFLAGS -Qunused-arguments -Wl,-rpath,@loader_path"
LD_POSTFIX=".1.dylib"
LD_POSTFIX_UNVERSIONED=".dylib"
@ -93,7 +93,9 @@ function build_nccl() {
-DCMAKE_CXX_FLAGS="$C_FLAGS $CPP_FLAGS"
make install
cp "lib/libnccl.so.1" "${INSTALL_DIR}/lib/libnccl.so.1"
ln -s "${INSTALL_DIR}/lib/libnccl.so.1" "${INSTALL_DIR}/lib/libnccl.so"
if [ ! -f "${INSTALL_DIR}/lib/libnccl.so" ]; then
ln -s "${INSTALL_DIR}/lib/libnccl.so.1" "${INSTALL_DIR}/lib/libnccl.so"
fi
cd ../..
}