pytorch/torch/csrc/autograd/python_legacy_variable.cpp
Will Feng 317cf7c874 Remove tensor_data() call in Python Variable() and nn.Parameter() constructors (#22821)
Summary:
As part of the Variable/Tensor merge, `variable.tensor_data()` should be removed in favor of `variable.detach()`. This PR removes the `tensor_data()` call sites in the Python `Variable()` and `nn.Parameter()` constructor paths.

Note that this PR is BC-breaking in the following way:
- For Python `Variable()` constructor:
Previously, updating a tensor in place after it had been used to create a Variable did not bump the Variable's version counter, which caused the following problem:
```python
t = torch.ones(2, 3)
v = torch.autograd.Variable(t).requires_grad_()
y = v * v
t.add_(1)  # This bumps version counter of `t`
y.sum().backward()  # This computes `v`'s gradient incorrectly before this patch, and throws error after this patch
```
After this patch, in-place updating a tensor after it's been used to create a Variable will also bump the Variable's version counter, thus preserving the correctness of the Variable's version counter.

- For Python `nn.Parameter()` constructor:
Previously, updating a tensor in place after it had been used to create an nn.Parameter did not bump the nn.Parameter's version counter, which caused the following problem:
```python
t = torch.ones(2, 3)
v = torch.nn.Parameter(t)
y = v * v
t.add_(1)  # This bumps version counter of `t`
y.sum().backward()  # This computes `v`'s gradient incorrectly before this patch, and throws error after this patch
```
After this patch, in-place updating a tensor after it's been used to create an nn.Parameter will also bump the nn.Parameter's version counter, thus preserving the correctness of the nn.Parameter's version counter.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/22821

Differential Revision: D16258030

Pulled By: yf225

fbshipit-source-id: 9a6d68cea1864893193dbefbb6ef0c1d5ca12d78
2019-07-14 21:09:29 -07:00

146 lines
5.6 KiB
C++

#include <torch/csrc/autograd/python_legacy_variable.h>
#include <ATen/ATen.h>
#include <torch/csrc/Exceptions.h>
#include <torch/csrc/autograd/python_function.h>
#include <torch/csrc/autograd/python_variable.h>
#include <torch/csrc/tensor/python_tensor.h>
#include <torch/csrc/jit/tracer.h>
using namespace at;
namespace torch { namespace autograd {
// `tp_new` implementation for `torch._C._LegacyVariableBase`: the legacy
// Python constructor `Variable(data, requires_grad, volatile, _grad_fn, name)`.
//
// Accepted arguments (all optional, positional or keyword):
//   data          - an existing tensor object; if omitted or None, an empty
//                   tensor of the default type is created instead.
//   requires_grad - integer flag forwarded to `set_requires_grad` when no
//                   `_grad_fn` is supplied.
//   volatile      - removed feature; a non-zero value only emits a UserWarning.
//   _grad_fn      - a Function object or None; when given, the new Variable's
//                   gradient edge is pointed at it.
//   name          - optional string (or None) passed to `set_name`.
//
// Returns a new Python object wrapping the Variable. Returns nullptr directly
// when argument parsing fails. C++ exceptions thrown in the body are expected
// to be translated into Python errors by the HANDLE_TH_ERRORS /
// END_HANDLE_TH_ERRORS macros (defined elsewhere).
static PyObject *THPVariable_pynew(PyTypeObject* type, PyObject *args, PyObject *kwds) {
  HANDLE_TH_ERRORS
  PyObject *data = nullptr;
  PyObject *grad_fn = nullptr;
  // The "b" format unit writes through an `unsigned char*`, so the targets are
  // declared with exactly that type.
  unsigned char is_volatile = 0;
  unsigned char requires_grad = 0;
  const char* name = nullptr;

  const char *accepted_args[] = {"data", "requires_grad", "volatile", "_grad_fn", "name", nullptr};
  if (!PyArg_ParseTupleAndKeywords(args, kwds, "|ObbOz", const_cast<char**>(accepted_args),
      &data, &requires_grad, &is_volatile, &grad_fn, &name))
    return nullptr;

  // Treat an explicit `_grad_fn=None` the same as an omitted argument.
  if (grad_fn == Py_None)
    grad_fn = nullptr;

  if (is_volatile) {
    // PyErr_WarnEx returns non-zero when the warning was turned into an
    // exception (e.g. under `-W error`). Propagate it instead of continuing
    // with a Python error pending.
    if (PyErr_WarnEx(PyExc_UserWarning,
        "volatile was removed and now has no effect. Use `with torch.no_grad():` "
        "instead.", 1) != 0) {
      throw python_error();
    }
  }

  if (is_volatile && requires_grad) {
    throw ValueError("Variable can't be volatile and require_grad at the same time!");
  }
  if (grad_fn && !THPFunction_Check(grad_fn)) {
    throw TypeError("_grad_fn has to be a Function object or None, but got %s",
        Py_TYPE(grad_fn)->tp_name);
  }

  Variable var;
  if (!data || data == Py_None) {
    // For legacy serialization code, create an empty tensor. This is also used
    // by nn.Parameter() with no arguments.
    auto type_id = torch::tensors::get_default_tensor_type_id();
    auto scalar_type = torch::tensors::get_default_scalar_type();
    auto options = TensorOptions(scalar_type)
        .device(computeDeviceType(type_id))
        .layout(layout_from_backend(tensorTypeIdToBackend(type_id)))
        .is_variable(true);
    var = at::empty({0}, options);
  } else if (THPVariable_Check(data)) {
    // `detach()` is used (rather than `tensor_data()`) so that in-place
    // updates to the source tensor also bump the new Variable's version
    // counter.
    var = reinterpret_cast<THPVariable*>(data)->cdata.detach();
  } else {
    throw torch::TypeError("Variable data has to be a tensor, but got %s",
        Py_TYPE(data)->tp_name);
  }

  // We set `var`'s `allow_tensor_metadata_change` to true here, because we want to
  // allow the following use case for backward compatibility:
  //
  // ```python
  // var = Variable(torch.randn(2, 3))
  // var.resize_(4, 5)
  // ```
  var.unsafeGetTensorImpl()->set_allow_tensor_metadata_change(true);

  if (grad_fn) {
    // Register `var` as an input of the supplied function and connect the
    // Variable's gradient edge to that input slot.
    auto grad_fn_ = THPFunction_asFunction(reinterpret_cast<THPFunction*>(grad_fn));
    Edge edge(grad_fn_, grad_fn_->add_input_metadata(var));
    var.set_gradient_edge(std::move(edge));
  } else {
    var.set_requires_grad(requires_grad != 0);
  }

  if (name) {
    var.set_name(name);
  }

  // While tracing, carry over the trace value of the source tensor (if it has
  // one) to the newly created Variable.
  if (jit::tracer::isTracing() && data && data != Py_None && THPVariable_Check(data)) {
    if (auto *v = jit::tracer::getValueTrace(reinterpret_cast<THPVariable*>(data)->cdata)) {
      jit::tracer::setValueTrace(var, v);
    }
  }

  return THPVariable_Wrap(std::move(var));
  END_HANDLE_TH_ERRORS
}
// Python type object registered as `torch._C._LegacyVariableBase`.
//
// The initializer is positional, so the order of the entries below must match
// the field order of `PyTypeObject`; the trailing comment on each line names
// the slot being filled. Only three slots carry non-null values here:
//   - tp_name:  the qualified name shown to Python.
//   - tp_flags: Py_TPFLAGS_BASETYPE is set so the type can be subclassed.
//   - tp_new:   THPVariable_pynew, which implements the legacy constructor.
// Every other slot is left as 0 / nullptr.
PyTypeObject THPLegacyVariableType = {
PyVarObject_HEAD_INIT(nullptr, 0)
"torch._C._LegacyVariableBase", /* tp_name */
0, /* tp_basicsize */
0, /* tp_itemsize */
nullptr, /* tp_dealloc */
nullptr, /* tp_print */
nullptr, /* tp_getattr */
nullptr, /* tp_setattr */
nullptr, /* tp_reserved */
nullptr, /* tp_repr */
nullptr, /* tp_as_number */
nullptr, /* tp_as_sequence */
nullptr, /* tp_as_mapping */
nullptr, /* tp_hash */
nullptr, /* tp_call */
nullptr, /* tp_str */
nullptr, /* tp_getattro */
nullptr, /* tp_setattro */
nullptr, /* tp_as_buffer */
Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /* tp_flags */
nullptr, /* tp_doc */
nullptr, /* tp_traverse */
nullptr, /* tp_clear */
nullptr, /* tp_richcompare */
0, /* tp_weaklistoffset */
nullptr, /* tp_iter */
nullptr, /* tp_iternext */
nullptr, /* tp_methods */
nullptr, /* tp_members */
nullptr, /* tp_getset */
nullptr, /* tp_base */
nullptr, /* tp_dict */
nullptr, /* tp_descr_get */
nullptr, /* tp_descr_set */
0, /* tp_dictoffset */
nullptr, /* tp_init */
nullptr, /* tp_alloc */
THPVariable_pynew /* tp_new */
};
// Finalizes THPLegacyVariableType and exposes it on `module` under the
// attribute name `_LegacyVariableBase`.
//
// Throws python_error if either PyType_Ready or PyModule_AddObject fails.
void init_legacy_variable(PyObject *module) {
  if (PyType_Ready(&THPLegacyVariableType) < 0) {
    throw python_error();
  }
  auto obj = reinterpret_cast<PyObject*>(&THPLegacyVariableType);
  // PyModule_AddObject steals a reference on success, so take one first.
  Py_INCREF(obj);
  if (PyModule_AddObject(module, "_LegacyVariableBase", obj) < 0) {
    // On failure the reference is NOT stolen; release it to keep the
    // refcount balanced before reporting the error.
    Py_DECREF(obj);
    throw python_error();
  }
}
}} // namespace torch::autograd