Fix grad_fn bindings when saved variable freed (#56499)

Summary:
Fixes https://github.com/pytorch/pytorch/issues/54472

Adds HANDLE_TH_ERRORS/END_HANDLE_TH_ERRORS to the generated Python bindings for grad_fn attributes (and a released-flag check to the vector SavedVariable getters), so that accessing saved variables after they have been freed raises a Python RuntimeError with the clarified ERR_BACKWARD_TWICE message instead of surfacing as an unhandled C++ error. Updates the tests accordingly.
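
A minimal sketch of the user-visible behavior this enables (mirroring the tanh case from the updated tests; the exact message text is the new ERR_BACKWARD_TWICE string below):

    import torch

    a = torch.randn(3, requires_grad=True)
    out = torch.tanh(a)
    out.sum().backward()            # saved values are freed here by default
    try:
        out.grad_fn._saved_result   # saved-variable getter on a freed value
    except RuntimeError as e:
        print(e)                    # "... after they have already been freed ..."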

Pull Request resolved: https://github.com/pytorch/pytorch/pull/56499

Reviewed By: albanD

Differential Revision: D27920742

Pulled By: soulitzer

fbshipit-source-id: d4f7ac8c0aa2173d25517277c393f8c66de68951
Author:       Jeffrey Wan
Date:         2021-04-22 13:38:05 -07:00
Committed by: Facebook GitHub Bot
Parent:       679cc7eb13
Commit:       2128a84a69

4 changed files with 48 additions and 5 deletions


@@ -4585,6 +4585,11 @@ for shape in [(1,), ()]:
        self.assertEqual(out.grad_fn._saved_dim, 0)  # int64_t -> int
        self.assertIsInstance(out.grad_fn._saved_dim, int)

        out.sum().backward()
        with self.assertRaisesRegex(RuntimeError, "after they have already been freed"):
            out.grad_fn._saved_tensors
        self.assertEqual(out.grad_fn._saved_dim, 0)

        a = torch.ones(2, 2, requires_grad=True)
        indices = torch.tensor([0, 1])
        out = a[:, indices]
@@ -4646,6 +4651,19 @@ for shape in [(1,), ()]:
        out = torch.tanh(a)
        self.assertEqual(out, out.grad_fn._saved_result)  # saved variable when output

        a = torch.randn(3, 5, requires_grad=True)
        b = torch.tensor([1, 0, 4])
        loss = nn.NLLLoss()
        out = loss(a, b)
        self.assertIsNone(out.grad_fn._saved_weight)
        loss = nn.NLLLoss(weight=torch.ones((5,)))
        out = loss(a, b)
        self.assertEqual(out.grad_fn._saved_weight, torch.ones((5,)))  # c10::optional<Tensor> -> Tensor?
        out.sum().backward()
        with self.assertRaisesRegex(RuntimeError, "after they have already been freed"):
            out.grad_fn._saved_weight

    def test_autograd_views_codegen(self):
        # This is not necessarily the absolute correct behavior, but this is the current
        # one. This test is here to make sure that any change to this behavior is detected


@@ -110,37 +110,59 @@ PY_GETSETDEF_STRUCT = CodeTemplate("""\
# Getter templates
GETTER_DEFINITION = CodeTemplate("""\
PyObject* THP${op}_${name}_getter(THPCppFunction *self, void *_unused) {
  HANDLE_TH_ERRORS
  auto prop = static_cast<${op}*>(self->cdata.get())->${name};
  ${body}
  END_HANDLE_TH_ERRORS
}
""")

GETTER_DEFINITION_SAVEDVAR = CodeTemplate("""\
PyObject* THP${op}_${name}_getter(THPCppFunction *self, void *_unused) {
  HANDLE_TH_ERRORS
  const auto& prop = static_cast<${op}*>(self->cdata.get())->${name}_;
  ${body}
  END_HANDLE_TH_ERRORS
}
""")

GETTER_DEFINITION_VEC_SAVEDVAR = CodeTemplate("""\
PyObject* THP${op}_${name}_getter(THPCppFunction *self, void *_unused) {
  HANDLE_TH_ERRORS
  const auto *node = static_cast<${op}*>(self->cdata.get());
  const auto& prop = node->${name}_;
  if (node->${name}_released_) {
    PyErr_SetString(PyExc_RuntimeError, ERR_BACKWARD_TWICE);
    return nullptr;
  }
  ${body}
  END_HANDLE_TH_ERRORS
}
""")

GETTER_DEFINITION_OPT = CodeTemplate("""\
PyObject* THP${op}_${name}_getter(THPCppFunction *self, void *_unused) {
  HANDLE_TH_ERRORS
  auto opt_prop = static_cast<${op}*>(self->cdata.get())->${name};
  if (!opt_prop.has_value()) {
    Py_RETURN_NONE;
  }
  auto prop = opt_prop.value();
  ${body}
  END_HANDLE_TH_ERRORS
}
""")

GETTER_DEFINITION_OPT_ARRAYREF = CodeTemplate("""\
PyObject* THP${op}_${name}_getter(THPCppFunction *self, void *_unused) {
  HANDLE_TH_ERRORS
  auto opt_prop = static_cast<${op}*>(self->cdata.get())->${name};
  if (!opt_prop.list.has_value()) {
    Py_RETURN_NONE;
  }
  auto prop = opt_prop.list.value();
  ${body}
  END_HANDLE_TH_ERRORS
}
""")
@@ -317,7 +339,7 @@ def process_function(info: DifferentiabilityInfo, template: CodeTemplate) -> str
        release_variables.append(f'{name}_released_ = true;')
        unpack.append(f'auto {name} = unpack_list({name}_);')
        asserts.append(f'TORCH_CHECK(!{name}_released_, ERR_BACKWARD_TWICE);')
-       getter_definitions.append(GETTER_DEFINITION_SAVEDVAR.substitute(
+       getter_definitions.append(GETTER_DEFINITION_VEC_SAVEDVAR.substitute(
            op=info.op, name=name, body=GETTER_BODY_VEC_SAVEDVAR))
    elif type == ListCType(OptionalCType(BaseCType(tensorT))):
        saved_variables.append(f'std::vector<SavedVariable> {name}_;')
@@ -328,7 +350,7 @@ def process_function(info: DifferentiabilityInfo, template: CodeTemplate) -> str
        release_variables.append(f'{name}_released_ = true;')
        unpack.append(f'auto {name} = unpack_opt_list({name}_);')
        asserts.append(f'TORCH_CHECK(!{name}_released_, ERR_BACKWARD_TWICE);')
-       getter_definitions.append(GETTER_DEFINITION_SAVEDVAR.substitute(
+       getter_definitions.append(GETTER_DEFINITION_VEC_SAVEDVAR.substitute(
            op=info.op, name=name, body=GETTER_BODY_VEC_SAVEDVAR))
    elif type == BaseCType(intArrayRefT):
        saved_variables.append(f'std::vector<int64_t> {name};')


@@ -8,6 +8,7 @@
#include "torch/csrc/autograd/generated/Functions.h"
#include "torch/csrc/autograd/python_cpp_function.h"
#include <torch/csrc/autograd/python_variable.h>
#include <torch/csrc/autograd/saved_variable.h>
namespace torch { namespace autograd { namespace generated {


@@ -140,8 +140,10 @@ Variable SavedVariable::unpack(std::shared_ptr<Node> saved_for) const {
}

const char* ERR_BACKWARD_TWICE =
-    "Trying to backward through the graph a second time, but the saved intermediate "
-    "results have already been freed. Specify retain_graph=True when calling "
-    ".backward() or autograd.grad() the first time.";
+    "Trying to backward through the graph a second time (or directly access saved "
+    "variables after they have already been freed). Saved intermediate values "
+    "of the graph are freed when you call .backward() or autograd.grad(). Specify "
+    "retain_graph=True if you need to backward through the graph a second time or "
+    "if you need to access saved variables after calling backward.";

}} // namespace torch::autograd
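
The updated message points users at retain_graph=True when they want to inspect saved variables after backward; a minimal sketch of that workaround (reusing the tanh example from the updated tests):

    import torch

    a = torch.randn(3, requires_grad=True)
    out = torch.tanh(a)
    out.sum().backward(retain_graph=True)  # keep the graph, so saved values are not freed
    print(out.grad_fn._saved_result)       # still accessible after backward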