pytorch/torch/csrc/utils/pybind.h
Edward Yang 65bb34d885 Remove TensorImpl::is_variable, deprecate Tensor::is_variable (#29653)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/29653

I didn't remove is_variable from Tensor for BC reasons, but I did
remove as many uses as I could from the codebase.
at::impl::variable_excluded_from_dispatch got moved to TensorBody.h
so that it's more widely accessible.

This diff is NOT semantics preserving.  Here are the major differences:

- In a number of native operator implementations, we tested that arguments
  are not variable.  I replaced these with asserts that variable is
  excluded from dispatch.  I actually don't think these asserts are really
  necessary now (they should certainly be true, but it's hard to get
  it wrong), but I've kept them for old times' sake.  At least, they'll detect
  if you call these functions before you've processed variable (indicating
  a bug in your kernel).

- There are a number of places where we do a per-tensor test for being a
  variable, for better error reporting when someone commits Tensor/Variable
  confusion.  Although these tests are substantively the same as the
  tests above, in these cases I decided to *delete* the test entirely.
  The reasoning is that in these cases, we didn't really care about
  dispatch (also, see above; I'm not too sure we really need the dispatch
  asserts), we cared about Tensor/Variable confusion.  Since Tensor/Variable
  confusion is impossible now, we don't need the tests.  One of the key
  factors which pushed me one way or another was whether or not a function
  was doing per-tensor validation; if I kept the assert in such functions,
  I'd repeatedly access the TLS.  Even if we want to bring back the asserts,
  they would have to go somewhere else.

  A similar idiom is the !x.defined() || x.is_variable() check that appears
  in a number of places; I treated it the same way.

- nuclear_norm's computation of compute_uv is a bit weird, but I think
  it's OK to just delete the is_variable case (I *suspect* that it is
  always the case that self.is_variable(), but it doesn't really matter.)
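The following is a minimal, self-contained C++ sketch of the TLS argument in the second bullet. It assumes the dispatch-exclusion state can be modeled as a plain thread_local bool. The names variable_excluded, tls_reads, read_variable_excluded, FakeTensor, kernel_with_entry_assert and validate_each are all hypothetical and are not ATen APIs. The sketch shows only why an assert at kernel entry costs one TLS read per call, while the same assert inside per-tensor validation costs one read per tensor.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for the thread-local state that
// at::impl::variable_excluded_from_dispatch consults.
thread_local bool variable_excluded = false;

// Counts TLS reads so the two validation styles can be compared.
thread_local std::size_t tls_reads = 0;

bool read_variable_excluded() {
  ++tls_reads;
  return variable_excluded;
}

// Hypothetical tensor stand-in; only tracks whether it is defined.
struct FakeTensor {
  bool defined = true;
};

// Style kept in native kernels: one assert at entry, one TLS read per call.
void kernel_with_entry_assert(const std::vector<FakeTensor>& args) {
  const bool ok = read_variable_excluded();
  assert(ok);
  (void)ok;
  for (const FakeTensor& t : args) {
    (void)t;  // real per-tensor work would go here
  }
}

// What converting a per-tensor "!x.defined() || x.is_variable()" test into
// the dispatch assert would look like: one TLS read per defined tensor,
// which is why the commit deletes these tests instead.
void validate_each(const std::vector<FakeTensor>& args) {
  for (const FakeTensor& t : args) {
    const bool ok = !t.defined || read_variable_excluded();
    assert(ok);
    (void)ok;
  }
}
```

With four defined tensors, the entry assert reads the flag once, whereas the per-tensor version reads it four times.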

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Test Plan: Imported from OSS

Differential Revision: D18496168

Pulled By: ezyang

fbshipit-source-id: 5a1ded931e0c10a6b758ba64a8380d34110e0c3e
2019-11-14 11:41:02 -08:00


#pragma once
#include <torch/csrc/python_headers.h>
#include <ATen/ATen.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <torch/csrc/DynamicTypes.h>
#include <torch/csrc/autograd/python_variable.h>
#include <torch/csrc/utils/python_tuples.h>
#include <torch/csrc/utils/python_numbers.h>
#include <stdexcept>
#include <utility>
namespace py = pybind11;
namespace pybind11 { namespace detail {
// torch.autograd.Variable <-> at::Tensor conversions (without unwrapping)
template <>
struct type_caster<at::Tensor> {
 public:
  PYBIND11_TYPE_CASTER(at::Tensor, _("at::Tensor"));

  bool load(handle src, bool) {
    PyObject* obj = src.ptr();
    if (THPVariable_Check(obj)) {
      value = reinterpret_cast<THPVariable*>(obj)->cdata;
      return true;
    }
    return false;
  }

  static handle
  cast(const at::Tensor& src, return_value_policy /* policy */, handle /* parent */) {
    return handle(THPVariable_Wrap(torch::autograd::Variable(src)));
  }
};

template<> struct type_caster<torch::autograd::Variable> {
 public:
  PYBIND11_TYPE_CASTER(torch::autograd::Variable, _("torch::autograd::Variable"));

  bool load(handle src, bool) {
    PyObject *source = src.ptr();
    if (THPVariable_Check(source)) {
      value = ((THPVariable*)source)->cdata;
      return true;
    } else {
      return false;
    }
  }

  static handle cast(torch::autograd::Variable src, return_value_policy /* policy */, handle /* parent */) {
    return handle(THPVariable_Wrap(std::move(src)));
  }
};
template<> struct type_caster<at::IntArrayRef> {
 public:
  PYBIND11_TYPE_CASTER(at::IntArrayRef, _("at::IntArrayRef"));

  bool load(handle src, bool) {
    PyObject *source = src.ptr();
    auto tuple = PyTuple_Check(source);
    if (tuple || PyList_Check(source)) {
      auto size = tuple ? PyTuple_GET_SIZE(source) : PyList_GET_SIZE(source);
      v_value.resize(size);
      for (int idx = 0; idx < size; idx++) {
        PyObject* obj = tuple ? PyTuple_GET_ITEM(source, idx) : PyList_GET_ITEM(source, idx);
        if (THPVariable_Check(obj)) {
          // Single-element tensor: extract its value with item<int64_t>()
          v_value[idx] = THPVariable_Unpack(obj).item<int64_t>();
        } else if (PyLong_Check(obj)) {
          // Plain Python int: THPUtils_unpackLong comes from python_numbers.h (included above)
          v_value[idx] = THPUtils_unpackLong(obj);
        } else {
          return false;
        }
      }
      value = v_value;
      return true;
    }
    return false;
  }

  static handle cast(at::IntArrayRef src, return_value_policy /* policy */, handle /* parent */) {
    return handle(THPUtils_packInt64Array(src.size(), src.data()));
  }

 private:
  // at::IntArrayRef does not own its data, so `value` points into this
  // buffer, which lives as long as the caster does.
  std::vector<int64_t> v_value;
};
// Pybind11 bindings for our optional type.
// http://pybind11.readthedocs.io/en/stable/advanced/cast/stl.html#c-17-library-containers
template <typename T>
struct type_caster<c10::optional<T>> : optional_caster<c10::optional<T>> {};
}} // namespace pybind11::detail
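The at::IntArrayRef caster above keeps a std::vector<int64_t> member because at::IntArrayRef does not own its data. The value must point at storage that outlives load(), and pybind11 keeps argument casters alive while the bound function runs. The standalone sketch below models that pattern without Python or ATen. PyElem, ScalarTensor, IntView and IntArrayCaster are hypothetical stand-ins for a Python sequence element, a single-element tensor, at::IntArrayRef and the real caster.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <variant>
#include <vector>

// Hypothetical model of one Python sequence element: a Python int,
// a single-element tensor (just its scalar), or a float (unsupported).
struct ScalarTensor { std::int64_t item; };
using PyElem = std::variant<std::int64_t, ScalarTensor, double>;

// Minimal non-owning view, standing in for at::IntArrayRef.
struct IntView {
  const std::int64_t* data = nullptr;
  std::size_t size = 0;
};

class IntArrayCaster {
 public:
  // Mirrors load(): returns false on any unsupported element so the
  // caller (pybind11, in the real code) can try another overload.
  bool load(const std::vector<PyElem>& seq) {
    v_value_.resize(seq.size());
    for (std::size_t i = 0; i < seq.size(); ++i) {
      if (const auto* t = std::get_if<ScalarTensor>(&seq[i])) {
        v_value_[i] = t->item;
      } else if (const auto* n = std::get_if<std::int64_t>(&seq[i])) {
        v_value_[i] = *n;
      } else {
        return false;  // e.g. a float element
      }
    }
    // The view borrows from v_value_, which the caster owns.
    value_ = IntView{v_value_.data(), v_value_.size()};
    return true;
  }

  IntView value() const { return value_; }

 private:
  std::vector<std::int64_t> v_value_;  // owning storage behind value_
  IntView value_;
};
```

For example, loading {2, ScalarTensor{3}, 5} yields a view of size 3 over {2, 3, 5}, while a sequence containing a double makes load() return false. Returning a view into a local vector instead would leave it dangling as soon as load() returned.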