Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/29653

I didn't remove is_variable from Tensor for BC reasons, but I did remove as many uses as I could from the codebase. at::impl::variable_excluded_from_dispatch got moved to TensorBody.h so that it's more widely accessible.

This diff is NOT semantics preserving. Here are the major differences:

- In a number of native operator implementations, we tested that arguments are not Variables. I replaced these with asserts that Variable is excluded from dispatch. I don't think these asserts are strictly necessary anymore (they should certainly hold, and it's hard to get this wrong), but I've kept them for old times' sake. At the least, they'll detect a kernel being called before Variable has been processed, indicating a bug in your kernel.

- In a number of places we did a per-tensor test for being a Variable, for better error reporting when someone commits Tensor/Variable confusion. Although these tests are substantively the same as the ones above, in these cases I decided to *delete* the test entirely. The reasoning is that here we didn't really care about dispatch (see also the point above; I'm not sure we really need the dispatch asserts at all); we cared about Tensor/Variable confusion. Since Tensor/Variable confusion is now impossible, we don't need the tests. A key factor that pushed me one way or the other was whether a function did per-tensor validation: if I kept the assert in such functions, I'd repeatedly access the TLS. Even if we want to bring the asserts back, they would have to go somewhere else. A similar idiom is the places where we do !x.defined() || x.is_variable(); I treated these equivalently.

- nuclear_norm's computation of compute_uv is a bit weird, but I think it's OK to just delete the is_variable case (I *suspect* that self.is_variable() always holds there, but it doesn't really matter).

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Test Plan: Imported from OSS

Differential Revision: D18496168

Pulled By: ezyang

fbshipit-source-id: 5a1ded931e0c10a6b758ba64a8380d34110e0c3e
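A minimal sketch of the assert pattern described above (the kernel name my_kernel is hypothetical; at::impl::variable_excluded_from_dispatch and TORCH_INTERNAL_ASSERT come from the ATen/c10 headers), showing what stands in for the old per-argument is_variable() checks:

    #include <ATen/ATen.h>

    at::Tensor my_kernel(const at::Tensor& self) {
      // Fires only if this kernel is reached before Variable has been
      // processed, i.e. a dispatch bug; the old per-tensor
      // !self.is_variable() checks guarded the same condition.
      TORCH_INTERNAL_ASSERT(at::impl::variable_excluded_from_dispatch());
      return self.clone();
    }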
97 lines · 2.8 KiB · C++
#pragma once

#include <torch/csrc/python_headers.h>

#include <ATen/ATen.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

#include <torch/csrc/DynamicTypes.h>
#include <torch/csrc/autograd/python_variable.h>
#include <torch/csrc/utils/python_tuples.h>
#include <torch/csrc/utils/python_numbers.h>

#include <stdexcept>
#include <utility>

namespace py = pybind11;

namespace pybind11 { namespace detail {

// torch.autograd.Variable <-> at::Tensor conversions (without unwrapping)
template <>
struct type_caster<at::Tensor> {
 public:
  PYBIND11_TYPE_CASTER(at::Tensor, _("at::Tensor"));

  bool load(handle src, bool) {
    PyObject* obj = src.ptr();
    if (THPVariable_Check(obj)) {
      value = reinterpret_cast<THPVariable*>(obj)->cdata;
      return true;
    }
    return false;
  }

  static handle
  cast(const at::Tensor& src, return_value_policy /* policy */, handle /* parent */) {
    return handle(THPVariable_Wrap(torch::autograd::Variable(src)));
  }
};
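
// Hypothetical usage sketch (not in the original header): with the caster
// above in scope, bound functions can take and return at::Tensor directly;
// pybind11 converts to and from torch.autograd.Variable objects without
// unwrapping. `m` stands for some py::module being populated.
//
//   m.def("clone_tensor", [](const at::Tensor& t) { return t.clone(); });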
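
// This caster mirrors the at::Tensor caster above so that signatures spelled
// in terms of torch::autograd::Variable also bind; as with at::Tensor, no
// unwrapping is performed.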
template<> struct type_caster<torch::autograd::Variable> {
 public:
  PYBIND11_TYPE_CASTER(torch::autograd::Variable, _("torch::autograd::Variable"));
  bool load(handle src, bool) {
    PyObject* source = src.ptr();
    if (THPVariable_Check(source)) {
      value = reinterpret_cast<THPVariable*>(source)->cdata;
      return true;
    } else {
      return false;
    }
  }
  static handle cast(torch::autograd::Variable src, return_value_policy /* policy */, handle /* parent */) {
    return handle(THPVariable_Wrap(std::move(src)));
  }
};

template<> struct type_caster<at::IntArrayRef> {
 public:
  PYBIND11_TYPE_CASTER(at::IntArrayRef, _("at::IntArrayRef"));

  bool load(handle src, bool) {
    PyObject* source = src.ptr();
    auto tuple = PyTuple_Check(source);
    if (tuple || PyList_Check(source)) {
      auto size = tuple ? PyTuple_GET_SIZE(source) : PyList_GET_SIZE(source);
      v_value.resize(size);
      for (Py_ssize_t idx = 0; idx < size; idx++) {
        PyObject* obj = tuple ? PyTuple_GET_ITEM(source, idx) : PyList_GET_ITEM(source, idx);
        if (THPVariable_Check(obj)) {
          v_value[idx] = THPVariable_Unpack(obj).item<int64_t>();
        } else if (PyLong_Check(obj)) {
          v_value[idx] = THPUtils_unpackLong(obj);
        } else {
          return false;
        }
      }
      value = v_value;
      return true;
    }
    return false;
  }
  static handle cast(at::IntArrayRef src, return_value_policy /* policy */, handle /* parent */) {
    return handle(THPUtils_packInt64Array(src.size(), src.data()));
  }
 private:
  // Owns the unpacked integers; `value` (an at::IntArrayRef) is a non-owning
  // view into this vector, so the loaded value is only valid while the
  // caster is alive.
  std::vector<int64_t> v_value;
};
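
// Hypothetical usage sketch (not in the original header): the caster above
// lets bound functions take at::IntArrayRef from Python lists or tuples of
// ints (or of 0-dim tensors). The loaded IntArrayRef aliases storage owned
// by the caster, so it is only valid for the duration of the call.
//
//   m.def("numel_of", [](at::IntArrayRef sizes) {
//     int64_t n = 1;
//     for (auto s : sizes) n *= s;
//     return n;
//   });
//   // Python: numel_of([2, 3, 4]) -> 24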

// Pybind11 bindings for our optional type.
// http://pybind11.readthedocs.io/en/stable/advanced/cast/stl.html#c-17-library-containers
template <typename T>
struct type_caster<c10::optional<T>> : optional_caster<c10::optional<T>> {};
}} // namespace pybind11::detail
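
// Hypothetical usage sketch (not in the original header): via the
// optional_caster specialization above, c10::optional<T> parameters accept
// Python None in addition to T:
//
//   m.def("maybe_dim", [](c10::optional<int64_t> dim) {
//     return dim.has_value() ? *dim : int64_t{-1};
//   });
//   // Python: maybe_dim(3) -> 3; maybe_dim(None) -> -1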