pytorch/tools/autograd/templates/python_variable_methods.cpp
Xiang Gao 2e5a8cee82 Customize the printing of namedtuple return (#17136)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/17112
```python
print("good", torch.randn(5,5,5).max(1))
print("terrible", torch.randn(5,5,10).max(1))
print("not as good", torch.randn(5,5,500).max(1))
print ("old behaviour = gold standard")
print(tuple(torch.randn(5,5,5).max(1)))
print(tuple(torch.randn(5,5,10).max(1)))
print(tuple(torch.randn(5,5,500).max(1)))
```
now gives
```
>>> import torch
>>> print("good", torch.randn(5,5,5).max(1))
good torch.return_types.max(
values=tensor([[ 1.2821,  1.8063,  1.8075,  1.3082, -0.1267],
        [ 0.3437,  0.7353,  1.2619,  0.7557,  1.6662],
        [ 0.8583,  1.8906,  1.0246,  1.7598,  1.1184],
        [ 1.7821,  0.0230,  0.9452,  1.0318,  1.0823],
        [ 0.4116, -0.0379, -0.1843,  1.4129,  1.8796]]),
indices=tensor([[4, 4, 3, 2, 1],
        [1, 2, 4, 1, 1],
        [2, 4, 0, 2, 1],
        [0, 2, 0, 3, 1],
        [0, 4, 4, 4, 4]]))
>>> print("terrible", torch.randn(5,5,10).max(1))
terrible torch.return_types.max(
values=tensor([[ 2.1272,  1.3664,  2.2067,  1.3974, -0.0883,  1.2505,  1.0074,  1.1217,
          0.3849,  0.6936],
        [ 0.6288, -0.4560,  1.2748,  1.5482,  1.2777,  1.6874,  0.7151,  0.6041,
          1.3572,  1.6232],
        [ 1.6703,  1.0075,  1.6480,  2.2839,  1.3390,  0.4938,  1.6449,  1.7628,
          0.8141,  2.5714],
        [ 0.7079,  1.8677,  3.2478,  1.5591,  2.4870,  0.8635, -0.1450,  1.6923,
          1.4924,  1.6298],
        [ 2.4056,  0.8002,  0.9317,  0.7455,  0.7866,  2.1191,  0.3492,  1.2095,
          1.8637,  1.7470]]),
indices=tensor([[1, 1, 0, 0, 0, 0, 3, 4, 4, 4],
        [4, 2, 2, 1, 2, 2, 3, 1, 1, 3],
        [0, 3, 3, 0, 2, 1, 4, 1, 0, 1],
        [4, 1, 3, 0, 3, 2, 0, 1, 4, 3],
        [1, 0, 3, 2, 1, 0, 0, 1, 0, 1]]))
>>> print("not as good", torch.randn(5,5,500).max(1))
not as good torch.return_types.max(
values=tensor([[ 0.3877,  0.7873,  1.8701,  ...,  0.5971,  1.6103, -0.3435],
        [ 1.1300,  2.2418,  1.4239,  ...,  1.3943,  0.3872,  1.6475],
        [ 2.0656,  1.3136,  0.9896,  ...,  2.3918,  0.8226,  1.0517],
        [ 1.1054,  0.9945,  1.0561,  ...,  2.1039,  1.1524,  3.0304],
        [ 1.5041,  2.2809,  1.0883,  ...,  0.8504,  2.4774,  1.1041]]),
indices=tensor([[4, 3, 1,  ..., 1, 4, 0],
        [4, 4, 4,  ..., 3, 0, 3],
        [3, 0, 1,  ..., 2, 2, 4],
        [0, 1, 1,  ..., 4, 2, 2],
        [1, 0, 4,  ..., 2, 0, 2]]))
>>> print ("old behaviour = gold standard")
old behaviour = gold standard
>>> print(tuple(torch.randn(5,5,5).max(1)))
(tensor([[ 1.1908,  1.1807,  1.3151,  1.7184,  0.3556],
        [ 0.3798,  0.9213,  0.3001,  1.3087,  2.2419],
        [ 1.4233,  1.4814,  1.9900,  1.7744,  1.3059],
        [ 1.0026, -0.0330,  1.3061,  1.8730,  2.0685],
        [ 1.3041,  1.6458,  1.3449,  1.8948,  3.6206]]), tensor([[0, 4, 3, 4, 0],
        [1, 1, 4, 0, 4],
        [4, 1, 0, 3, 3],
        [1, 2, 1, 4, 0],
        [3, 3, 0, 3, 3]]))
>>> print(tuple(torch.randn(5,5,10).max(1)))
(tensor([[-0.1232,  0.8275,  0.6732,  1.1223,  0.8247,  1.2851,  1.6009,  1.9979,
          1.9109,  0.7313],
        [ 0.2260,  0.5922,  1.6928,  0.6024,  2.1158,  3.0619,  0.5653,  0.7426,
          0.8316,  0.6346],
        [ 0.4319,  0.2231,  0.5255,  1.7620,  1.1657,  0.8875,  0.5782,  0.6506,
          0.5032,  1.7097],
        [ 0.4137,  1.7265,  1.4260,  2.0301,  1.2244,  0.7128,  2.6345,  0.7230,
          1.3553,  1.6508],
        [ 1.0684,  1.7195,  1.4068,  0.7076, -0.0242,  0.8474,  0.8754,  1.7108,
          0.2188,  1.1584]]), tensor([[0, 1, 3, 4, 2, 3, 4, 2, 1, 0],
        [1, 4, 0, 0, 3, 2, 0, 0, 3, 3],
        [2, 3, 1, 1, 4, 0, 1, 4, 4, 4],
        [0, 4, 1, 3, 2, 0, 2, 0, 3, 1],
        [1, 0, 0, 0, 0, 3, 3, 3, 2, 0]]))
>>> print(tuple(torch.randn(5,5,500).max(1)))
(tensor([[0.9395, 1.5572, 1.8797,  ..., 2.0494, 0.8202, 0.9623],
        [1.7937, 0.7225, 1.8836,  ..., 0.7927, 1.4976, 1.1813],
        [0.8558, 1.6943, 1.4192,  ..., 0.8327, 1.9661, 0.4197],
        [1.2993, 1.4995, 0.9357,  ..., 0.7810, 1.3030, 2.6216],
        [1.4206, 1.8315, 1.0338,  ..., 1.4312, 1.3198, 1.5233]]), tensor([[0, 4, 3,  ..., 3, 0, 2],
        [0, 1, 0,  ..., 0, 4, 3],
        [3, 4, 3,  ..., 3, 0, 0],
        [3, 2, 3,  ..., 1, 2, 1],
        [1, 2, 4,  ..., 3, 1, 3]]))
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17136

Differential Revision: D14250021

Pulled By: VitalyFedyunin

fbshipit-source-id: aae72f03b35980063b1ac1f07b8353eddb0c8b93
2019-02-28 13:07:26 -08:00

707 lines
26 KiB
C++

// ${generated_comment}
#include <Python.h>
#include "torch/csrc/DynamicTypes.h"
#include "torch/csrc/Exceptions.h"
#include "torch/csrc/Size.h"
#include "torch/csrc/autograd/python_variable.h"
#include "torch/csrc/autograd/utils/python_error_messages.h"
#include "torch/csrc/autograd/utils/wrap_outputs.h"
#include "torch/csrc/autograd/utils/python_arg_parsing.h"
#include "torch/csrc/jit/tracer.h"
#ifdef USE_CUDA
#include "torch/csrc/cuda/Stream.h"
#include "torch/csrc/cuda/Event.h"
#endif
#include "torch/csrc/utils/cuda_lazy_init.h"
#include "torch/csrc/utils/object_ptr.h"
#include "torch/csrc/utils/python_arg_parser.h"
#include "torch/csrc/utils/python_numbers.h"
#include "torch/csrc/utils/python_strings.h"
#include "torch/csrc/utils/python_tuples.h"
#include "torch/csrc/utils/tensor_apply.h"
#include "torch/csrc/utils/tensor_list.h"
#include "torch/csrc/utils/tensor_new.h"
#include "torch/csrc/utils/tensor_numpy.h"
#include "torch/csrc/utils/tensor_types.h"
#include "torch/csrc/utils/structseq.h"
#include <ATen/ATen.h>
#include "c10/util/Optional.h"
#include "python_variable_methods_dispatch.h"
#include <stdexcept>
using at::DeviceGuard;
using at::device_of;
using at::OptionalDeviceGuard;
using at::Backend;
using at::Scalar;
using at::ScalarType;
using at::Tensor;
using namespace torch::autograd::utils;
namespace torch { namespace autograd {
// Python binding registered as "apply_" (single-argument, METH_O).
// Refuses to run on a tensor that requires grad; otherwise forwards the
// tensor and the Python argument to torch::utils::apply_ and wraps the result.
static PyObject * THPVariable_apply_(PyObject* self, PyObject* arg)
{
  HANDLE_TH_ERRORS
  auto& var = reinterpret_cast<THPVariable*>(self)->cdata;
  const bool tracks_grad = var.requires_grad();
  if (tracks_grad) {
    throw std::runtime_error(
        "Can't call apply_() on Variable that requires grad. Use "
        "var.detach().apply_() instead.");
  }
  return THPVariable_Wrap(torch::utils::apply_(var, arg));
  END_HANDLE_TH_ERRORS
}
// Python binding for "size".
//   size(dim) -> extent of one dimension (goes through the tracer's
//                getSizeOf while a trace is being recorded)
//   size()    -> all extents, built with THPSize_New
static PyObject * THPVariable_size(PyObject* self, PyObject* args, PyObject* kwargs)
{
  HANDLE_TH_ERRORS
  static PythonArgParser parser({
    "size(int64_t dim)",
    "size()",
  });
  auto& var = reinterpret_cast<THPVariable*>(self)->cdata;
  ParsedArgs<3> parsed_args;
  auto r = parser.parse(args, kwargs, parsed_args);

  if (r.idx == 0) {
    const int64_t dim = r.toInt64(0);
    if (jit::tracer::isTracing()) {
      return wrap(jit::tracer::getSizeOf(var, dim));
    }
    return wrap(var.size(dim));
  }
  if (r.idx == 1) {
    // The generic wrap() is not usable here: IntArrayRef corresponds to both
    // torch.Size and tuple on the Python side.
    return THPSize_New(var);
  }
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}
// Python binding for "stride".
//   stride(dim) -> stride of one dimension
//   stride()    -> all strides, packed with THPUtils_packInt64Array
static PyObject * THPVariable_stride(PyObject* self, PyObject* args, PyObject* kwargs)
{
  HANDLE_TH_ERRORS
  static PythonArgParser parser({
    "stride(int64_t dim)",
    "stride()",
  });
  auto& var = reinterpret_cast<THPVariable*>(self)->cdata;
  ParsedArgs<3> parsed_args;
  auto r = parser.parse(args, kwargs, parsed_args);

  if (r.idx == 0) {
    return wrap(var.stride(r.toInt64(0)));
  }
  if (r.idx == 1) {
    // ATen spells the accessor "strides" (plural).
    IntArrayRef all_strides = var.strides();
    // The generic wrap() is not usable here: IntArrayRef corresponds to both
    // torch.Size and tuple on the Python side.
    return THPUtils_packInt64Array(all_strides.size(), all_strides.data());
  }
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}
// Python binding for "get_device" (no arguments): wraps Tensor::get_device().
static PyObject * THPVariable_get_device(PyObject* self_, PyObject* args)
{
  HANDLE_TH_ERRORS
  auto& tensor = reinterpret_cast<THPVariable*>(self_)->cdata;
  return wrap(tensor.get_device());
  END_HANDLE_TH_ERRORS
}
// Python binding for "storage_offset" (no arguments): wraps
// Tensor::storage_offset().
static PyObject * THPVariable_storage_offset(PyObject* self_, PyObject* args)
{
  HANDLE_TH_ERRORS
  auto& tensor = reinterpret_cast<THPVariable*>(self_)->cdata;
  return wrap(tensor.storage_offset());
  END_HANDLE_TH_ERRORS
}
// Python binding shared by "dim" and "ndimension" (see the method table):
// returns Tensor::dim() as a Python integer.
static PyObject * THPVariable_dim(PyObject* self, PyObject* args)
{
  HANDLE_TH_ERRORS
  auto& var = reinterpret_cast<THPVariable*>(self)->cdata;
  return THPUtils_packInt64(var.dim());
  END_HANDLE_TH_ERRORS
}
// Calls Tensor::contiguous() under an AutoNoGIL scope with the device guard
// set from `self`.
static Tensor dispatch_contiguous(const Tensor & self) {
  AutoNoGIL gil_scope;
  OptionalDeviceGuard guard(device_of(self));
  return self.contiguous();
}
// Python binding for "contiguous" (no arguments).
// A tensor that is already contiguous is returned as-is (new reference to
// `self`); otherwise the work is delegated to dispatch_contiguous.
static PyObject * THPVariable_contiguous(PyObject* self, PyObject* args)
{
  HANDLE_TH_ERRORS
  auto& var = reinterpret_cast<THPVariable*>(self)->cdata;

  if (!var.is_contiguous()) {
    return THPVariable_Wrap(dispatch_contiguous(var));
  }

  // Fast path: already contiguous, so neither the GIL nor the current device
  // has to be touched.
  //
  // NOTE: this logic is duplicated from VariableType.cpp. The call to
  // contiguous() must show up in a trace whether or not it actually runs
  // here, so the node is recorded by hand.
  if (jit::tracer::isTracing()) {
    auto state = jit::tracer::getTracingState();
    auto contiguous_node =
        state->graph->create(jit::aten::contiguous, /*num_outputs=*/0);
    jit::tracer::recordSourceLocation(contiguous_node);
    jit::tracer::addInputs(contiguous_node, "self", var);
    state->graph->insertNode(contiguous_node);
    jit::tracer::addOutput(contiguous_node, var);
  }
  Py_INCREF(self);
  return self;
  END_HANDLE_TH_ERRORS
}
// Calls self.copy_(other, non_blocking) under an AutoNoGIL scope with the
// device guard set from `self`.
static Tensor dispatch_copy_(Tensor & self, const Tensor & other, bool non_blocking) {
  AutoNoGIL gil_scope;
  OptionalDeviceGuard guard(device_of(self));
  return self.copy_(other, non_blocking);
}
// Python binding for "copy_".
// Accepts `copy_(other, non_blocking=False)`; the `async` spelling of the
// flag is still parsed but marked deprecated.
static PyObject * THPVariable_copy_(PyObject* self, PyObject* args, PyObject* kwargs)
{
  HANDLE_TH_ERRORS
  static PythonArgParser parser({
    "copy_(Tensor other, bool non_blocking=False)",
    "copy_(Tensor other, bool async=False)|deprecated"
  });
  auto& dst = reinterpret_cast<THPVariable*>(self)->cdata;
  ParsedArgs<2> parsed_args;
  auto parsed = parser.parse(args, kwargs, parsed_args);
  return THPVariable_Wrap(
      dispatch_copy_(dst, parsed.tensor(0), parsed.toBool(1)));
  END_HANDLE_TH_ERRORS
}
// Extracts the value of a one-element tensor as a C double.
// Throws ValueError when the tensor does not hold exactly one element.
static double dispatch_to_CDouble(const Tensor & self) {
  AutoNoGIL gil_scope;
  OptionalDeviceGuard guard(device_of(self));
  const bool single_element = (self.numel() == 1);
  if (!single_element) {
    throw ValueError("only one element tensors can be converted to Python scalars");
  }
  return self.item<double>();
}
// Extracts the value of a one-element tensor as std::complex<double>.
// Throws ValueError when the tensor does not hold exactly one element.
static std::complex<double> dispatch_to_CComplexDouble(const Tensor & self) {
  AutoNoGIL gil_scope;
  OptionalDeviceGuard guard(device_of(self));
  const bool single_element = (self.numel() == 1);
  if (!single_element) {
    throw ValueError("only one element tensors can be converted to Python scalars");
  }
  return self.item<std::complex<double>>();
}
// Extracts the value of a one-element tensor as int64_t.
// Throws ValueError when the tensor does not hold exactly one element.
static int64_t dispatch_to_CLong(const Tensor & self) {
  AutoNoGIL gil_scope;
  OptionalDeviceGuard guard(device_of(self));
  const bool single_element = (self.numel() == 1);
  if (!single_element) {
    throw ValueError("only one element tensors can be converted to Python scalars");
  }
  return self.item<int64_t>();
}
// Implements `__float__`: emits a tracer data-flow warning, then converts the
// tensor to a Python float via dispatch_to_CDouble.
static PyObject * THPVariable_float_scalar(PyObject* self, PyObject* args) {
  HANDLE_TH_ERRORS
  jit::tracer::warn("Converting a tensor to a Python float", jit::tracer::WARN_PYTHON_DATAFLOW);
  auto& var = reinterpret_cast<THPVariable*>(self)->cdata;
  return wrap(dispatch_to_CDouble(var));
  END_HANDLE_TH_ERRORS
}
// Implements `__int__` / `__long__`: emits a tracer data-flow warning, then
// converts the tensor to a Python integer.
static PyObject * THPVariable_integral_scalar(PyObject* self, PyObject* args) {
  HANDLE_TH_ERRORS
  jit::tracer::warn("Converting a tensor to a Python integer", jit::tracer::WARN_PYTHON_DATAFLOW);
  auto& var = reinterpret_cast<THPVariable*>(self)->cdata;

  if (!isFloatingType(var.type().scalarType())) {
    return wrap(dispatch_to_CLong(var));
  }
  // Floating-point tensors are read as double and converted on the Python
  // side: going through item<int64_t> would trigger ATen overflow checks,
  // while the Python integral type (long in python2) cannot overflow.
  return THPUtils_packDoubleAsInt(dispatch_to_CDouble(var));
  END_HANDLE_TH_ERRORS
}
// Implements `__index__`, the counterpart of `__int__` that Python invokes
// when the object is used as a slice index. Only single-element tensors of an
// integral type are accepted; anything else raises TypeError.
static PyObject * THPVariable_index_scalar(PyObject* self, PyObject* args) {
  HANDLE_TH_ERRORS
  jit::tracer::warn("Converting a tensor to a Python index", jit::tracer::WARN_PYTHON_DATAFLOW);
  auto& var = reinterpret_cast<THPVariable*>(self)->cdata;
  // TODO: change the element-count test to `var.dim() == 0` once we expose
  // scalars in PyTorch.
  const bool usable_as_index =
      isIntegralType(var.type().scalarType()) && var.numel() == 1;
  if (!usable_as_index) {
    throw TypeError("only integer tensors of a single element can be converted to an index");
  }
  return wrap(dispatch_to_CLong(var));
  END_HANDLE_TH_ERRORS
}
// Computes `1 - self` under an AutoNoGIL scope with the device guard set from
// `self`. The only caller in this file restricts input to byte tensors.
static Tensor dispatch_invert(const Tensor & self) {
  AutoNoGIL gil_scope;
  OptionalDeviceGuard guard(device_of(self));
  return 1 - self;
}
// Implements `__invert__` (the `~` operator). Only byte tensors are
// supported; any other scalar type raises TypeError.
static PyObject * THPVariable_invert(PyObject* self, PyObject* args) {
  HANDLE_TH_ERRORS
  auto& var = reinterpret_cast<THPVariable*>(self)->cdata;
  const bool is_byte = (var.type().scalarType() == at::kByte);
  if (!is_byte) {
    throw TypeError("~ (operator.invert) is only implemented on byte tensors");
  }
  return THPVariable_Wrap(dispatch_invert(var));
  END_HANDLE_TH_ERRORS
}
// Device-only overload of dispatch_to: moves `self` to `device` under an
// AutoNoGIL scope.
static Tensor dispatch_to(const Tensor & self, Device device, bool non_blocking, bool copy) {
  AutoNoGIL gil_scope;
  // NOTE: aten::to is recorded in the graph here during tracing. aten::to and
  // the tracer disagree about TensorOptions fields that are absent: aten::to
  // inherits them from `self`, whereas the tracer assumes the defaults
  // (e.g. float for the scalar type). Starting from self.options() and only
  // overriding the device fully specifies every field, so the recorded trace
  // is correct.
  const auto target_options = self.options().device(device);
  return self.to(target_options, non_blocking, copy);
}
// Dtype-only overload of dispatch_to: converts `self` to `dtype` under an
// AutoNoGIL scope.
static Tensor dispatch_to(const Tensor & self, ScalarType dtype, bool non_blocking, bool copy) {
  AutoNoGIL gil_scope;
  return self.to(dtype, non_blocking, copy);
}
// Device-and-dtype overload of dispatch_to: converts `self` to the given
// device and dtype under an AutoNoGIL scope.
static Tensor dispatch_to(const Tensor & self, Device device, ScalarType dtype, bool non_blocking, bool copy) {
  AutoNoGIL gil_scope;
  return self.to(device, dtype, non_blocking, copy);
}
// Python binding for "cpu" (no arguments): dispatches a conversion to the CPU
// device with non_blocking=false and copy=false.
static PyObject * THPVariable_cpu(PyObject* self, PyObject* args)
{
  HANDLE_TH_ERRORS
  auto& var = reinterpret_cast<THPVariable*>(self)->cdata;
  const at::Device cpu_device(at::DeviceType::CPU);
  return THPVariable_Wrap(dispatch_to(var, cpu_device, false, false));
  END_HANDLE_TH_ERRORS
}
// Python binding for "cuda".
// Accepts `cuda(device=None, non_blocking=False)`; the `async` spelling of the
// flag is still parsed but marked deprecated. A missing device means a CUDA
// device without an explicit index; a non-CUDA device is rejected.
static PyObject * THPVariable_cuda(PyObject* self, PyObject* args, PyObject* kwargs)
{
  HANDLE_TH_ERRORS
  static PythonArgParser parser({
    "cuda(Device? device=None, bool non_blocking=False)",
    "cuda(Device? device=None, bool async=False)|deprecated"
  });
  auto& var = reinterpret_cast<THPVariable*>(self)->cdata;
  ParsedArgs<2> parsed_args;
  auto parsed = parser.parse(args, kwargs, parsed_args);
  auto device = parsed.isNone(0) ? at::Device(at::DeviceType::CUDA) : parsed.device(0);
  AT_CHECK(device.is_cuda(), "Invalid device, must be cuda device");
  // CUDA is initialized only after the target has been validated.
  torch::utils::cuda_lazy_init();
  return THPVariable_Wrap(
      dispatch_to(var, device, parsed.toBool(1), false));
  END_HANDLE_TH_ERRORS
}
// Shared helper for the dtype shortcut methods (byte, char, double, ...):
// dispatches a conversion to `scalarType` with non_blocking=false, copy=false.
static PyObject * THPVariable_to_type(PyObject* self, ScalarType scalarType) {
  HANDLE_TH_ERRORS
  auto& var = reinterpret_cast<THPVariable*>(self)->cdata;
  return THPVariable_Wrap(
      dispatch_to(var, scalarType, /*non_blocking=*/false, /*copy=*/false));
  END_HANDLE_TH_ERRORS
}
// Python binding for "byte": converts to ScalarType::Byte.
static PyObject * THPVariable_byte(PyObject* self, PyObject* args) {
  constexpr ScalarType target = ScalarType::Byte;
  return THPVariable_to_type(self, target);
}
// Python binding for "char": converts to ScalarType::Char.
static PyObject * THPVariable_char(PyObject* self, PyObject* args) {
  constexpr ScalarType target = ScalarType::Char;
  return THPVariable_to_type(self, target);
}
// Python binding for "double": converts to ScalarType::Double.
static PyObject * THPVariable_double(PyObject* self, PyObject* args) {
  constexpr ScalarType target = ScalarType::Double;
  return THPVariable_to_type(self, target);
}
// Python binding for "float": converts to ScalarType::Float.
static PyObject * THPVariable_float(PyObject* self, PyObject* args) {
  constexpr ScalarType target = ScalarType::Float;
  return THPVariable_to_type(self, target);
}
// Python binding for "half": converts to ScalarType::Half.
static PyObject * THPVariable_half(PyObject* self, PyObject* args) {
  constexpr ScalarType target = ScalarType::Half;
  return THPVariable_to_type(self, target);
}
// Python binding for "int": converts to ScalarType::Int.
static PyObject * THPVariable_int(PyObject* self, PyObject* args) {
  constexpr ScalarType target = ScalarType::Int;
  return THPVariable_to_type(self, target);
}
// Python binding for "long": converts to ScalarType::Long.
static PyObject * THPVariable_long(PyObject* self, PyObject* args) {
  constexpr ScalarType target = ScalarType::Long;
  return THPVariable_to_type(self, target);
}
// Python binding for "short": converts to ScalarType::Short.
static PyObject * THPVariable_short(PyObject* self, PyObject* args) {
  constexpr ScalarType target = ScalarType::Short;
  return THPVariable_to_type(self, target);
}
// Python binding for "element_size" (no arguments): returns the tensor type's
// elementSizeInBytes() as a Python integer.
static PyObject * THPVariable_element_size(PyObject* self, PyObject* args)
{
  HANDLE_TH_ERRORS
  auto& var = reinterpret_cast<THPVariable*>(self)->cdata;
  size_t bytes_per_element = var.type().elementSizeInBytes();
  return THPUtils_packInt64(bytes_per_element);
  END_HANDLE_TH_ERRORS
}
// Python binding for "numpy" (no arguments).
// Emits a tracer data-flow warning, refuses tensors that require grad, and
// otherwise converts the tensor's data() with torch::utils::tensor_to_numpy.
static PyObject * THPVariable_numpy(PyObject* self, PyObject* arg)
{
  HANDLE_TH_ERRORS
  jit::tracer::warn("Converting a tensor to a NumPy array", jit::tracer::WARN_PYTHON_DATAFLOW);
  auto& var = reinterpret_cast<THPVariable*>(self)->cdata;
  const bool tracks_grad = var.requires_grad();
  if (tracks_grad) {
    throw std::runtime_error(
        "Can't call numpy() on Variable that requires grad. "
        "Use var.detach().numpy() instead.");
  }
  return torch::utils::tensor_to_numpy(var.data());
  END_HANDLE_TH_ERRORS
}
// Python binding for "record_stream" (single argument, METH_O).
// With CUDA support compiled in, the argument must be a Stream object; the
// tensor's data pointer is then registered with the CUDA caching allocator
// for that stream. Without CUDA support the call always throws.
// TODO: move this to ATen. We would need to expose Stream objects in ATen.
static PyObject * THPVariable_record_stream(PyObject* self, PyObject* arg)
{
  HANDLE_TH_ERRORS
#ifdef USE_CUDA
  auto& var = reinterpret_cast<THPVariable*>(self)->cdata;
  const bool is_stream = THCPStream_Check(arg);
  if (!is_stream) {
    return PyErr_Format(PyExc_TypeError, "expected Stream object");
  }
  void* storage_ptr = var.data_ptr();
  auto* py_stream = reinterpret_cast<THCPStream*>(arg);
  c10::cuda::CUDACachingAllocator::recordStream(
      storage_ptr, at::cuda::CUDAStream::unpack(py_stream->cdata));
  Py_RETURN_NONE;
#else
  throw std::runtime_error("PyTorch compiled without CUDA support");
#endif
  END_HANDLE_TH_ERRORS
}
// Python binding for "requires_grad_(requires_grad=True)".
// Sets the flag in place and returns the same variable, wrapped. Throws when
// asked to clear the flag on a non-leaf, or to set it on a tensor that is not
// floating point.
static PyObject * THPVariable_requires_grad_(PyObject* self, PyObject* args, PyObject* kwargs)
{
  HANDLE_TH_ERRORS
  static PythonArgParser parser({
    "requires_grad_(bool requires_grad=True)",
  });
  auto& var = reinterpret_cast<THPVariable*>(self)->cdata;
  ParsedArgs<1> parsed_args;
  auto parsed = parser.parse(args, kwargs, parsed_args);
  const bool want_grad = parsed.toBool(0);

  // Open question kept from the original author: should requires_grad=True
  // on a non-leaf throw as well? `var.requires_grad = True` throws in that
  // situation, but here it is convenient to let it be a no-op.
  const bool clearing_on_non_leaf = !want_grad && !var.is_leaf();
  if (clearing_on_non_leaf) {
    throw std::runtime_error(autograd::utils::requires_grad_leaf_error(want_grad));
  }
  if (want_grad && !var.is_floating_point()) {
    throw std::runtime_error("only Tensors of floating point dtype can require gradients");
  }
  var.set_requires_grad(want_grad);
  return THPVariable_Wrap(var);
  END_HANDLE_TH_ERRORS
}
// Thin forwarding helper: reports Tensor::is_contiguous() for `self`.
inline bool dispatch_is_contiguous(Tensor & self) {
  const bool contiguous = self.is_contiguous();
  return contiguous;
}
// Python binding for "is_contiguous" (no arguments): wraps the result of
// dispatch_is_contiguous.
static PyObject * THPVariable_is_contiguous(PyObject* self_, PyObject* args)
{
  HANDLE_TH_ERRORS
  auto& tensor = reinterpret_cast<THPVariable*>(self_)->cdata;
  return wrap(dispatch_is_contiguous(tensor));
  END_HANDLE_TH_ERRORS
}
// Python binding for "item" (no arguments).
// Emits a tracer data-flow warning, then converts the tensor to a Python
// number: floating-point tensors via double, complex tensors via
// std::complex<double>, everything else via int64_t.
static PyObject * THPVariable_item(PyObject* self, PyObject* args)
{
  HANDLE_TH_ERRORS
  jit::tracer::warn("Converting a tensor to a Python number", jit::tracer::WARN_PYTHON_DATAFLOW);
  auto& var = reinterpret_cast<THPVariable*>(self)->cdata;
  if (var.is_floating_point()) {
    return wrap(dispatch_to_CDouble(var));
  }
  if (var.is_complex()) {
    return wrap(dispatch_to_CComplexDouble(var));
  }
  return wrap(dispatch_to_CLong(var));
  END_HANDLE_TH_ERRORS
}
// Python binding for "map_(other, callable)".
// Refuses to run when either tensor requires grad; otherwise forwards both
// tensors and the Python callable to torch::utils::map_ and wraps the result.
static PyObject * THPVariable_map_(PyObject* self, PyObject* args, PyObject* kwargs)
{
  HANDLE_TH_ERRORS
  static PythonArgParser parser({ "map_(Tensor other, PyObject* callable)" });
  auto& var = reinterpret_cast<THPVariable*>(self)->cdata;
  ParsedArgs<2> parsed_args;
  auto parsed = parser.parse(args, kwargs, parsed_args);
  Variable other = parsed.tensor(0);
  const bool any_tracks_grad = var.requires_grad() || other.requires_grad();
  if (any_tracks_grad) {
    throw std::runtime_error(
        "Can't call map_() on Variable that requires grad. Use "
        "var.detach().map_() instead.");
  }
  return THPVariable_Wrap(torch::utils::map_(var, other, parsed.pyobject(1)));
  END_HANDLE_TH_ERRORS
}
// Python binding for "map2_(x, y, callable)".
// Refuses to run when any of the three tensors requires grad; otherwise
// forwards them and the Python callable to torch::utils::map2_ and wraps the
// result.
static PyObject * THPVariable_map2_(PyObject* self, PyObject* args, PyObject* kwargs)
{
  HANDLE_TH_ERRORS
  static PythonArgParser parser({ "map2_(Tensor x, Tensor y, PyObject* callable)" });
  auto& var = reinterpret_cast<THPVariable*>(self)->cdata;
  ParsedArgs<3> parsed_args;
  auto parsed = parser.parse(args, kwargs, parsed_args);
  Variable x = parsed.tensor(0);
  Variable y = parsed.tensor(1);
  const bool any_tracks_grad =
      var.requires_grad() || x.requires_grad() || y.requires_grad();
  if (any_tracks_grad) {
    throw std::runtime_error(
        "Can't call map2_() on Variable that requires grad. Use "
        "var.detach().map2_() instead.");
  }
  return THPVariable_Wrap(torch::utils::map2_(var, x, y, parsed.pyobject(2)));
  END_HANDLE_TH_ERRORS
}
// Python binding for "new": builds a tensor through
// torch::utils::legacy_tensor_new using this tensor's type, with the device
// guard set from this tensor. Unlike the new_* siblings, no tracer warning is
// emitted here.
static PyObject * THPVariable_new(PyObject* self, PyObject* args, PyObject* kwargs)
{
  HANDLE_TH_ERRORS
  auto& var = reinterpret_cast<THPVariable*>(self)->cdata;
  OptionalDeviceGuard guard(device_of(var));
  return THPVariable_Wrap(
      torch::utils::legacy_tensor_new(var.type(), args, kwargs));
  END_HANDLE_TH_ERRORS
}
// Python binding for "new_empty": emits a LEGACY_CONSTRUCTOR tracer warning,
// then builds the result through torch::utils::new_empty using this tensor's
// type, with the device guard set from this tensor.
static PyObject * THPVariable_new_empty(PyObject* self, PyObject* args, PyObject* kwargs)
{
  HANDLE_TH_ERRORS
  jit::tracer::warn("new_empty", jit::tracer::LEGACY_CONSTRUCTOR);
  auto& var = reinterpret_cast<THPVariable*>(self)->cdata;
  OptionalDeviceGuard guard(device_of(var));
  return THPVariable_Wrap(
      torch::utils::new_empty(var.type(), args, kwargs));
  END_HANDLE_TH_ERRORS
}
// Python binding for "new_full": emits a LEGACY_CONSTRUCTOR tracer warning,
// then builds the result through torch::utils::new_full using this tensor's
// type, with the device guard set from this tensor.
static PyObject * THPVariable_new_full(PyObject* self, PyObject* args, PyObject* kwargs)
{
  HANDLE_TH_ERRORS
  jit::tracer::warn("new_full", jit::tracer::LEGACY_CONSTRUCTOR);
  auto& var = reinterpret_cast<THPVariable*>(self)->cdata;
  OptionalDeviceGuard guard(device_of(var));
  return THPVariable_Wrap(
      torch::utils::new_full(var.type(), args, kwargs));
  END_HANDLE_TH_ERRORS
}
// Python binding for "new_ones": emits a LEGACY_CONSTRUCTOR tracer warning,
// then builds the result through torch::utils::new_ones using this tensor's
// type, with the device guard set from this tensor.
static PyObject * THPVariable_new_ones(PyObject* self, PyObject* args, PyObject* kwargs)
{
  HANDLE_TH_ERRORS
  jit::tracer::warn("new_ones", jit::tracer::LEGACY_CONSTRUCTOR);
  auto& var = reinterpret_cast<THPVariable*>(self)->cdata;
  OptionalDeviceGuard guard(device_of(var));
  return THPVariable_Wrap(
      torch::utils::new_ones(var.type(), args, kwargs));
  END_HANDLE_TH_ERRORS
}
// Python binding for "new_tensor": emits a LEGACY_CONSTRUCTOR tracer warning,
// then builds the result through torch::utils::new_tensor using this tensor's
// type, with the device guard set from this tensor.
static PyObject * THPVariable_new_tensor(PyObject* self, PyObject* args, PyObject* kwargs)
{
  HANDLE_TH_ERRORS
  jit::tracer::warn("new_tensor", jit::tracer::LEGACY_CONSTRUCTOR);
  auto& var = reinterpret_cast<THPVariable*>(self)->cdata;
  OptionalDeviceGuard guard(device_of(var));
  return THPVariable_Wrap(
      torch::utils::new_tensor(var.type(), args, kwargs));
  END_HANDLE_TH_ERRORS
}
// Python binding for "new_zeros": emits a LEGACY_CONSTRUCTOR tracer warning,
// then builds the result through torch::utils::new_zeros using this tensor's
// type, with the device guard set from this tensor.
static PyObject * THPVariable_new_zeros(PyObject* self, PyObject* args, PyObject* kwargs)
{
  HANDLE_TH_ERRORS
  jit::tracer::warn("new_zeros", jit::tracer::LEGACY_CONSTRUCTOR);
  auto& var = reinterpret_cast<THPVariable*>(self)->cdata;
  OptionalDeviceGuard guard(device_of(var));
  return THPVariable_Wrap(
      torch::utils::new_zeros(var.type(), args, kwargs));
  END_HANDLE_TH_ERRORS
}
// Python binding for "storage" (no arguments): turns the tensor's storage()
// into a Python object with createPyObject.
static PyObject * THPVariable_storage(PyObject* self, PyObject* arg)
{
  HANDLE_TH_ERRORS
  auto& var = reinterpret_cast<THPVariable*>(self)->cdata;
  return createPyObject(var.storage());
  END_HANDLE_TH_ERRORS
}
// Python binding for "storage_type" (no arguments): creates the Python
// storage object for this tensor and returns a new reference to its Python
// type. The temporary storage object is held by a THPObjectPtr.
static PyObject * THPVariable_storage_type(PyObject* self, PyObject* arg)
{
  HANDLE_TH_ERRORS
  auto& var = reinterpret_cast<THPVariable*>(self)->cdata;
  auto storage_obj = THPObjectPtr(createPyObject(var.storage()));
  PyObject* type_obj = (PyObject*)Py_TYPE(storage_obj);
  // The caller receives its own reference to the type object.
  Py_INCREF(type_obj);
  return type_obj;
  END_HANDLE_TH_ERRORS
}
// Python binding for "to".
// Arguments are parsed by parse_to_conversion into an optional device, an
// optional dtype, and the non_blocking / copy flags. Behaviour by case:
//   - no device, no dtype, copy=false : return `self` unchanged (new reference)
//   - no device, no dtype, copy=true  : convert to the tensor's current dtype
//   - dtype only / device only / both : dispatch to the matching overload
// CUDA is lazily initialized whenever the requested device is a CUDA device.
static PyObject * THPVariable_to(PyObject* self, PyObject* args, PyObject* kwargs)
{
  HANDLE_TH_ERRORS
  auto parsed = parse_to_conversion(args, kwargs, /*allow_copy*/ true);
  auto& device = std::get<0>(parsed);
  auto& scalarType = std::get<1>(parsed);
  auto non_blocking = std::get<2>(parsed);
  auto copy = std::get<3>(parsed);
  auto& self_ = reinterpret_cast<THPVariable*>(self)->cdata;
  if (device && device->is_cuda()) {
    torch::utils::cuda_lazy_init();
  }
  if (!device && !scalarType) {
    if (!copy) {
      // Nothing to convert and no copy requested: hand back the same object.
      Py_INCREF(self);
      return self;
    }
    // Only a copy was requested. The dtype optional is empty here, so it must
    // not be dereferenced; target the tensor's current scalar type instead and
    // let the copy flag take effect (presumably yielding a fresh copy; see
    // the Tensor.to documentation).
    return THPVariable_Wrap(
        dispatch_to(self_, self_.type().scalarType(), non_blocking, copy));
  }
  if (!device) {
    return THPVariable_Wrap(dispatch_to(self_, *scalarType, non_blocking, copy));
  }
  if (!scalarType) {
    return THPVariable_Wrap(dispatch_to(self_, *device, non_blocking, copy));
  }
  return THPVariable_Wrap(dispatch_to(self_, *device, *scalarType, non_blocking, copy));
  END_HANDLE_TH_ERRORS
}
// Python binding for "tolist" (no arguments).
// Emits a tracer data-flow warning, then converts the tensor's data() to a
// Python list with torch::utils::tensor_to_list.
static PyObject * THPVariable_tolist(PyObject* self, PyObject* args)
{
  HANDLE_TH_ERRORS
  jit::tracer::warn("Converting a tensor to a Python list", jit::tracer::WARN_PYTHON_DATAFLOW);
  // Bind by reference like every other method in this file; the tensor is
  // only read here, so copying the Variable handle is unnecessary.
  auto& self_ = reinterpret_cast<THPVariable*>(self)->cdata;
  return torch::utils::tensor_to_list(self_.data());
  END_HANDLE_TH_ERRORS
}
// Python binding for "type".
// Accepts `type(dtype=None, non_blocking=False)`; the `async` spelling of the
// flag is still parsed but marked deprecated.
//   - dtype is None        : return the tensor's type name as a string
//   - dtype is a type/str  : resolve the name with type_from_string; this
//                            determines the scalar type and possibly a
//                            different device type
//   - dtype is a dtype     : only the scalar type changes
// Any other object raises TypeError.
static PyObject * THPVariable_type(PyObject* self, PyObject* args, PyObject* kwargs)
{
  HANDLE_TH_ERRORS
  static PythonArgParser parser({
    "type(PyObject* dtype=None, bool non_blocking=False)",
    "type(PyObject* dtype=None, bool async=False)|deprecated"
  });
  auto& var = reinterpret_cast<THPVariable*>(self)->cdata;
  ParsedArgs<2> parsed_args;
  auto r = parser.parse(args, kwargs, parsed_args);
  if (r.isNone(0)) {
    return THPUtils_packString(torch::utils::type_to_string(var.type()));
  }

  // Classify the argument: either a type name or a dtype object.
  auto requested = r.pyobject(0);
  std::string requested_name;
  bool given_dtype = false;
  if (PyType_Check(requested)) {
    requested_name = (requested == THPVariableClass)
        ? "torch.Tensor"
        : ((PyTypeObject*)requested)->tp_name;
  } else if (THPUtils_checkString(requested)) {
    requested_name = THPUtils_unpackString(requested);
  } else if (THPDtype_Check(requested)) {
    given_dtype = true;
  } else {
    throw TypeError("dtype must be a type, str, or dtype object");
  }

  // Work out the target scalar type and device, starting from the tensor's
  // current device.
  ScalarType target_scalar;
  Device target_device = var.device();
  if (given_dtype) {
    target_scalar = r.scalartype(0);
  } else {
    auto& named_type = torch::utils::type_from_string(requested_name);
    target_scalar = named_type.scalarType();
    auto named_device_type = backendToDeviceType(named_type.backend());
    // The device is replaced only when the named type lives on a different
    // kind of device; otherwise the current device is kept.
    if (named_device_type != target_device.type()) {
      target_device = at::Device(named_device_type);
    }
  }
  if (target_device.is_cuda()) {
    torch::utils::cuda_lazy_init();
  }
  return THPVariable_Wrap(dispatch_to(var, target_device, target_scalar, /*non_blocking=*/ r.toBool(1), /*copy=*/ false));
  END_HANDLE_TH_ERRORS
}
// generated methods start here
${py_methods}
// Implements `__bool__` / `__nonzero__`: emits a tracer data-flow warning and
// then defers to THPVariable_is_nonzero, which is not defined in the
// hand-written part of this file (presumably one of the generated methods).
static PyObject * THPVariable_bool(PyObject* self, PyObject* args) {
  jit::tracer::warn("Converting a tensor to a Python boolean", jit::tracer::WARN_PYTHON_DATAFLOW);
  PyObject* truth_value = THPVariable_is_nonzero(self, args);
  return truth_value;
}
// Method table for the Python tensor type.
// Each entry is {python name, C function, calling-convention flags, docstring};
// every docstring slot here is NULL. Several Python names share one C
// function (e.g. __add__/__radd__, __bool__/__nonzero__, dim/ndimension).
// Functions such as THPVariable_add or THPVariable_numel are not defined in
// the hand-written part of this file; presumably they come from the
// ${py_methods} expansion. ${py_method_defs} is a template placeholder,
// presumably replaced by further generated entries. The trailing {NULL} entry
// is the sentinel that terminates the table.
PyMethodDef variable_methods[] = {
// Operator and conversion protocol methods.
{"__add__", (PyCFunction)THPVariable_add, METH_VARARGS | METH_KEYWORDS, NULL},
{"__radd__", (PyCFunction)THPVariable_add, METH_VARARGS | METH_KEYWORDS, NULL},
{"__iadd__", (PyCFunction)THPVariable_add_, METH_VARARGS | METH_KEYWORDS, NULL},
{"__rmul__", (PyCFunction)THPVariable_mul, METH_VARARGS | METH_KEYWORDS, NULL},
{"__mul__", (PyCFunction)THPVariable_mul, METH_VARARGS | METH_KEYWORDS, NULL},
{"__imul__", (PyCFunction)THPVariable_mul_, METH_VARARGS | METH_KEYWORDS, NULL},
{"__sub__", (PyCFunction)THPVariable_sub, METH_VARARGS | METH_KEYWORDS, NULL},
{"__isub__", (PyCFunction)THPVariable_sub_, METH_VARARGS | METH_KEYWORDS, NULL},
{"__div__", (PyCFunction)THPVariable_div, METH_VARARGS | METH_KEYWORDS, NULL},
{"__truediv__", (PyCFunction)THPVariable_div, METH_VARARGS | METH_KEYWORDS, NULL},
{"__idiv__", (PyCFunction)THPVariable_div_, METH_VARARGS | METH_KEYWORDS, NULL},
{"__mod__", (PyCFunction)THPVariable_remainder, METH_VARARGS | METH_KEYWORDS, NULL},
{"__bool__", (PyCFunction)THPVariable_bool, METH_NOARGS, NULL},
{"__float__", (PyCFunction)THPVariable_float_scalar, METH_NOARGS, NULL},
{"__int__", (PyCFunction)THPVariable_integral_scalar, METH_NOARGS, NULL},
{"__long__", (PyCFunction)THPVariable_integral_scalar, METH_NOARGS, NULL},
{"__index__", (PyCFunction)THPVariable_index_scalar, METH_NOARGS, NULL},
{"__nonzero__", (PyCFunction)THPVariable_bool, METH_NOARGS, NULL},
{"__invert__", (PyCFunction)THPVariable_invert, METH_NOARGS, NULL},
{"__matmul__", (PyCFunction)THPVariable_matmul, METH_VARARGS | METH_KEYWORDS, NULL},
// Named methods, in alphabetical order.
{"apply_", (PyCFunction)THPVariable_apply_, METH_O, NULL},
{"byte", (PyCFunction)THPVariable_byte, METH_NOARGS, NULL},
{"char", (PyCFunction)THPVariable_char, METH_NOARGS, NULL},
{"contiguous", (PyCFunction)THPVariable_contiguous, METH_NOARGS, NULL},
{"copy_", (PyCFunction)THPVariable_copy_, METH_VARARGS | METH_KEYWORDS, NULL},
{"cpu", (PyCFunction)THPVariable_cpu, METH_NOARGS, NULL},
{"cuda", (PyCFunction)THPVariable_cuda, METH_VARARGS | METH_KEYWORDS, NULL},
{"dim", (PyCFunction)THPVariable_dim, METH_NOARGS, NULL},
{"double", (PyCFunction)THPVariable_double, METH_NOARGS, NULL},
{"element_size", (PyCFunction)THPVariable_element_size, METH_NOARGS, NULL},
{"float", (PyCFunction)THPVariable_float, METH_NOARGS, NULL},
{"get_device", (PyCFunction)THPVariable_get_device, METH_NOARGS, NULL},
{"half", (PyCFunction)THPVariable_half, METH_NOARGS, NULL},
{"int", (PyCFunction)THPVariable_int, METH_NOARGS, NULL},
{"is_contiguous", (PyCFunction)THPVariable_is_contiguous, METH_NOARGS, NULL},
{"item", (PyCFunction)THPVariable_item, METH_NOARGS, NULL},
{"long", (PyCFunction)THPVariable_long, METH_NOARGS, NULL},
{"map_", (PyCFunction)THPVariable_map_, METH_VARARGS | METH_KEYWORDS, NULL},
{"map2_", (PyCFunction)THPVariable_map2_, METH_VARARGS | METH_KEYWORDS, NULL},
{"ndimension", (PyCFunction)THPVariable_dim, METH_NOARGS, NULL},
{"nelement", (PyCFunction)THPVariable_numel, METH_NOARGS, NULL},
{"new", (PyCFunction)THPVariable_new, METH_VARARGS | METH_KEYWORDS, NULL},
{"new_empty", (PyCFunction)THPVariable_new_empty, METH_VARARGS | METH_KEYWORDS, NULL},
{"new_full", (PyCFunction)THPVariable_new_full, METH_VARARGS | METH_KEYWORDS, NULL},
{"new_ones", (PyCFunction)THPVariable_new_ones, METH_VARARGS | METH_KEYWORDS, NULL},
{"new_tensor", (PyCFunction)THPVariable_new_tensor, METH_VARARGS | METH_KEYWORDS, NULL},
{"new_zeros", (PyCFunction)THPVariable_new_zeros, METH_VARARGS | METH_KEYWORDS, NULL},
{"numpy", (PyCFunction)THPVariable_numpy, METH_NOARGS, NULL},
{"record_stream", (PyCFunction)THPVariable_record_stream, METH_O, NULL},
{"requires_grad_", (PyCFunction)THPVariable_requires_grad_, METH_VARARGS | METH_KEYWORDS, NULL},
{"short", (PyCFunction)THPVariable_short, METH_NOARGS, NULL},
{"size", (PyCFunction)THPVariable_size, METH_VARARGS | METH_KEYWORDS, NULL},
{"storage", (PyCFunction)THPVariable_storage, METH_NOARGS, NULL},
{"storage_offset", (PyCFunction)THPVariable_storage_offset, METH_NOARGS, NULL},
{"storage_type", (PyCFunction)THPVariable_storage_type, METH_NOARGS, NULL},
{"stride", (PyCFunction)THPVariable_stride, METH_VARARGS | METH_KEYWORDS, NULL},
{"to", (PyCFunction)THPVariable_to, METH_VARARGS | METH_KEYWORDS, NULL},
{"tolist", (PyCFunction)THPVariable_tolist, METH_NOARGS, NULL},
{"type", (PyCFunction)THPVariable_type, METH_VARARGS | METH_KEYWORDS, NULL},
${py_method_defs}
{NULL}
};
}} // namespace torch::autograd