Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-06 12:20:52 +01:00)
Fix grad_fn bindings when saved variable freed (#56499)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/54472

Adds HANDLE_TH_ERRORS to python bindings for grad_fn attrs and updates tests.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/56499

Reviewed By: albanD

Differential Revision: D27920742

Pulled By: soulitzer

fbshipit-source-id: d4f7ac8c0aa2173d25517277c393f8c66de68951
parent 679cc7eb13
commit 2128a84a69
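Not part of the commit itself, but as context for the test changes below: a minimal Python sketch of the behavior this fix produces, assuming a build that includes it. Accessing a tensor-backed _saved_* attribute after backward has freed the graph now surfaces as a catchable RuntimeError (the getters previously lacked HANDLE_TH_ERRORS, see issue 54472):

    import torch

    a = torch.ones(2, 2, requires_grad=True)
    out = torch.tanh(a)
    out.grad_fn._saved_result        # fine before backward: returns the saved output tensor

    out.sum().backward()             # saved intermediate values of the graph are freed here
    try:
        out.grad_fn._saved_result    # with this fix: raises instead of an unhandled C++ error
    except RuntimeError as e:
        assert "after they have already been freed" in str(e)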
@@ -4585,6 +4585,11 @@ for shape in [(1,), ()]:
         self.assertEqual(out.grad_fn._saved_dim, 0)  # int64_t -> int
         self.assertIsInstance(out.grad_fn._saved_dim, int)
 
+        out.sum().backward()
+        with self.assertRaisesRegex(RuntimeError, "after they have already been freed"):
+            out.grad_fn._saved_tensors
+        self.assertEqual(out.grad_fn._saved_dim, 0)
+
         a = torch.ones(2, 2, requires_grad=True)
         indices = torch.tensor([0, 1])
         out = a[:, indices]
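The assertion kept at the end of the added block makes a point worth spelling out: non-tensor saved fields stay readable after backward, only SavedVariable-backed ones are freed. A standalone sketch, using torch.cat as an assumed stand-in for the op in the test (cat's grad_fn saves dim as a plain int64_t, so it exposes _saved_dim):

    import torch

    a = torch.ones(2, 2, requires_grad=True)
    out = torch.cat([a, a], dim=0)      # grad_fn saves dim as a plain int64_t
    out.sum().backward()                # frees the tensors saved by the graph
    assert out.grad_fn._saved_dim == 0  # ints are copied into the node, so this still works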
@@ -4646,6 +4651,19 @@ for shape in [(1,), ()]:
         out = torch.tanh(a)
         self.assertEqual(out, out.grad_fn._saved_result)  # saved variable when output
 
+        a = torch.randn(3, 5, requires_grad=True)
+        b = torch.tensor([1, 0, 4])
+        loss = nn.NLLLoss()
+        out = loss(a, b)
+        self.assertIsNone(out.grad_fn._saved_weight)
+        loss = nn.NLLLoss(weight=torch.ones((5,)))
+        out = loss(a, b)
+        self.assertEqual(out.grad_fn._saved_weight, torch.ones((5,)))  # c10:optional<Tensor> -> Tensor?
+
+        out.sum().backward()
+        with self.assertRaisesRegex(RuntimeError, "after they have already been freed"):
+            out.grad_fn._saved_weight
+
     def test_autograd_views_codegen(self):
         # This is not necessarily the absolute correct behavior, but this is the current
         # one. This test is here to make sure that any change to this behavior is detected
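Mirroring the NLLLoss assertions just above, a small standalone sketch of how an optional saved tensor looks from Python: absent maps to None, present maps to the tensor value (the _saved_weight name and values are taken from the test, not introduced here):

    import torch
    import torch.nn as nn

    a = torch.randn(3, 5, requires_grad=True)
    b = torch.tensor([1, 0, 4])

    out = nn.NLLLoss()(a, b)                      # no weight passed
    print(out.grad_fn._saved_weight)              # None

    out = nn.NLLLoss(weight=torch.ones(5))(a, b)  # weight passed
    print(out.grad_fn._saved_weight)              # tensor([1., 1., 1., 1., 1.])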
@@ -110,37 +110,59 @@ PY_GETSETDEF_STRUCT = CodeTemplate("""\
 # Getter templates
 GETTER_DEFINITION = CodeTemplate("""\
 PyObject* THP${op}_${name}_getter(THPCppFunction *self, void *_unused) {
+  HANDLE_TH_ERRORS
   auto prop = static_cast<${op}*>(self->cdata.get())->${name};
   ${body}
+  END_HANDLE_TH_ERRORS
 }
 """)
 
 GETTER_DEFINITION_SAVEDVAR = CodeTemplate("""\
 PyObject* THP${op}_${name}_getter(THPCppFunction *self, void *_unused) {
+  HANDLE_TH_ERRORS
   const auto& prop = static_cast<${op}*>(self->cdata.get())->${name}_;
   ${body}
+  END_HANDLE_TH_ERRORS
 }
 """)
 
+GETTER_DEFINITION_VEC_SAVEDVAR = CodeTemplate("""\
+PyObject* THP${op}_${name}_getter(THPCppFunction *self, void *_unused) {
+  HANDLE_TH_ERRORS
+  const auto *node = static_cast<${op}*>(self->cdata.get());
+  const auto& prop = node->${name}_;
+  if (node->${name}_released_) {
+    PyErr_SetString(PyExc_RuntimeError, ERR_BACKWARD_TWICE);
+    return nullptr;
+  }
+  ${body}
+  END_HANDLE_TH_ERRORS
+}
+""")
+
 GETTER_DEFINITION_OPT = CodeTemplate("""\
 PyObject* THP${op}_${name}_getter(THPCppFunction *self, void *_unused) {
+  HANDLE_TH_ERRORS
   auto opt_prop = static_cast<${op}*>(self->cdata.get())->${name};
   if (!opt_prop.has_value()) {
     Py_RETURN_NONE;
   }
   auto prop = opt_prop.value();
   ${body}
+  END_HANDLE_TH_ERRORS
 }
 """)
 
 GETTER_DEFINITION_OPT_ARRAYREF = CodeTemplate("""\
 PyObject* THP${op}_${name}_getter(THPCppFunction *self, void *_unused) {
+  HANDLE_TH_ERRORS
   auto opt_prop = static_cast<${op}*>(self->cdata.get())->${name};
   if (!opt_prop.list.has_value()) {
     Py_RETURN_NONE;
   }
   auto prop = opt_prop.list.value();
   ${body}
+  END_HANDLE_TH_ERRORS
 }
 """)
 
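Each of these templates expands into one getter per saved field of a generated autograd Node, and the getters appear on the Python side as _saved_* attributes of the grad_fn. A small sketch of how to list them for a given op (plain dir(), nothing specific to this commit):

    import torch
    import torch.nn as nn

    out = nn.NLLLoss()(torch.randn(3, 5, requires_grad=True), torch.tensor([1, 0, 4]))
    saved = [name for name in dir(out.grad_fn) if name.startswith("_saved_")]
    print(saved)   # includes '_saved_weight', per the test above, plus the other saved fields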
@@ -317,7 +339,7 @@ def process_function(info: DifferentiabilityInfo, template: CodeTemplate) -> str
             release_variables.append(f'{name}_released_ = true;')
             unpack.append(f'auto {name} = unpack_list({name}_);')
             asserts.append(f'TORCH_CHECK(!{name}_released_, ERR_BACKWARD_TWICE);')
-            getter_definitions.append(GETTER_DEFINITION_SAVEDVAR.substitute(
+            getter_definitions.append(GETTER_DEFINITION_VEC_SAVEDVAR.substitute(
                 op=info.op, name=name, body=GETTER_BODY_VEC_SAVEDVAR))
         elif type == ListCType(OptionalCType(BaseCType(tensorT))):
             saved_variables.append(f'std::vector<SavedVariable> {name}_;')
@@ -328,7 +350,7 @@ def process_function(info: DifferentiabilityInfo, template: CodeTemplate) -> str
             release_variables.append(f'{name}_released_ = true;')
             unpack.append(f'auto {name} = unpack_opt_list({name}_);')
             asserts.append(f'TORCH_CHECK(!{name}_released_, ERR_BACKWARD_TWICE);')
-            getter_definitions.append(GETTER_DEFINITION_SAVEDVAR.substitute(
+            getter_definitions.append(GETTER_DEFINITION_VEC_SAVEDVAR.substitute(
                 op=info.op, name=name, body=GETTER_BODY_VEC_SAVEDVAR))
         elif type == BaseCType(intArrayRefT):
             saved_variables.append(f'std::vector<int64_t> {name};')
@@ -8,6 +8,7 @@
 #include "torch/csrc/autograd/generated/Functions.h"
 #include "torch/csrc/autograd/python_cpp_function.h"
 #include <torch/csrc/autograd/python_variable.h>
+#include <torch/csrc/autograd/saved_variable.h>
 
 
 namespace torch { namespace autograd { namespace generated {
@@ -140,8 +140,10 @@ Variable SavedVariable::unpack(std::shared_ptr<Node> saved_for) const {
 }
 
 const char* ERR_BACKWARD_TWICE =
-    "Trying to backward through the graph a second time, but the saved intermediate "
-    "results have already been freed. Specify retain_graph=True when calling "
-    ".backward() or autograd.grad() the first time.";
+    "Trying to backward through the graph a second time (or directly access saved "
+    "variables after they have already been freed). Saved intermediate values "
+    "of the graph are freed when you call .backward() or autograd.grad(). Specify "
+    "retain_graph=True if you need to backward through the graph a second time or "
+    "if you need to access saved variables after calling backward.";
 
 }} // namespace torch::autograd
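The reworded message covers two situations: a second backward over a freed graph and direct access to freed saved variables. A small sketch (not part of the commit) of the retain_graph=True escape hatch it points to:

    import torch

    a = torch.ones(2, 2, requires_grad=True)
    out = torch.tanh(a)

    out.sum().backward(retain_graph=True)  # keep the graph and its saved values alive
    out.grad_fn._saved_result              # still accessible
    out.sum().backward()                   # a second backward also works

    # Without retain_graph=True on the first call, both lines above would raise
    # RuntimeError with the ERR_BACKWARD_TWICE message shown in this hunk.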