pytorch/tools/autograd/templates/python_torch_functions.cpp
Edward Z. Yang f064c5aa33
Expunge all occurrences of torch._C._VariableFunctions (#5525)
Some of the call-sites now look a little hokey with this
removed, saving that for another patch.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
2018-03-02 12:19:44 -05:00

160 lines
6.1 KiB
C++

// ${generated_comment}
// Python bindings for torch.* functions implemented through ATen.
//
// The functions are bound as static methods on the class
// torch._C._VariableFunctions, which is also aliased as Variable._torch;
// they are also copied into the 'torch' module.
#include <Python.h>
#include "torch/csrc/Exceptions.h"
#include "torch/csrc/autograd/python_variable.h"
#include "torch/csrc/autograd/utils/wrap_outputs.h"
#include "torch/csrc/utils/python_arg_parser.h"
#include "torch/csrc/utils/tensor_new.h"
#include "torch/csrc/utils/tensor_numpy.h"
#include "python_torch_functions_dispatch.h"
using at::Tensor;
using at::Scalar;
using at::ScalarType;
using at::Backend;
using namespace torch::autograd::utils;
namespace torch { namespace autograd {
// Forwards `requires_grad` to the Variable obtained from
// as_variable_ref(self), then returns `self`.
static Tensor set_requires_grad(Tensor self, bool requires_grad) {
  auto&& variable = as_variable_ref(self);
  variable.set_requires_grad(requires_grad);
  return self;
}
// Verifies that the type of `result` (an `out` tensor, per the error text)
// equals `type`. On mismatch, calls at::runtime_error with a message naming
// both types; on match, does nothing.
static void check_out_dtype_matches(Tensor result, const at::Type &type) {
  const bool mismatch = (result.type() != type);
  if (!mismatch) {
    return;
  }
  at::runtime_error(
      "dtype corresponding to %s does not match type of out parameter (%s)",
      type.toString(),
      result.type().toString());
}
// Two-sided clamp: returns self.clamp(min, max).
// An AutoNoGIL guard and an AutoGPU guard (built from `self`) are alive for
// the duration of the call -- presumably releasing the Python GIL and
// selecting the device of `self`; confirm against their definitions.
static Tensor dispatch_clamp(const Tensor & self, Scalar min, Scalar max) {
  AutoNoGIL gil_release;
  AutoGPU device_guard(self);
  return self.clamp(min, max);
}
// Lower-bound-only clamp: returns self.clamp_min(min).
// An AutoNoGIL guard and an AutoGPU guard (built from `self`) are alive for
// the duration of the call -- presumably releasing the Python GIL and
// selecting the device of `self`; confirm against their definitions.
static Tensor dispatch_clamp_min(const Tensor & self, Scalar min) {
  AutoNoGIL gil_release;
  AutoGPU device_guard(self);
  return self.clamp_min(min);
}
// Upper-bound-only clamp: returns self.clamp_max(max).
// An AutoNoGIL guard and an AutoGPU guard (built from `self`) are alive for
// the duration of the call -- presumably releasing the Python GIL and
// selecting the device of `self`; confirm against their definitions.
static Tensor dispatch_clamp_max(const Tensor & self, Scalar max) {
  AutoNoGIL gil_release;
  AutoGPU device_guard(self);
  return self.clamp_max(max);
}
// Python entry point for clamp(input, min=None, max=None).
//
// The single Python signature is routed to one of three C++ helpers depending
// on which bounds were supplied:
//   - both min and max -> dispatch_clamp
//   - only min         -> dispatch_clamp_min
//   - only max         -> dispatch_clamp_max
// Supplying neither bound throws std::runtime_error inside the
// HANDLE_TH_ERRORS / END_HANDLE_TH_ERRORS region.
static PyObject * THPVariable_clamp(PyObject* module, PyObject* args, PyObject* kwargs)
{
  HANDLE_TH_ERRORS
  static PythonArgParser parser({
    "clamp(Tensor input, Scalar min=None, Scalar max=None)",
  });
  ParsedArgs<3> parsed_args;
  auto r = parser.parse(args, kwargs, parsed_args);

  // Argument slots: 0 = input, 1 = min, 2 = max.
  const bool has_min = !r.isNone(1);
  const bool has_max = !r.isNone(2);

  if (!has_min && !has_max) {
    throw std::runtime_error("At least one of 'min' or 'max' must not be None");
  }
  if (has_min && has_max) {
    return THPVariable_Wrap(dispatch_clamp(r.tensor(0), r.scalar(1), r.scalar(2)));
  }
  if (has_min) {
    return THPVariable_Wrap(dispatch_clamp_min(r.tensor(0), r.scalar(1)));
  }
  return THPVariable_Wrap(dispatch_clamp_max(r.tensor(0), r.scalar(2)));
  END_HANDLE_TH_ERRORS
}
// Python entry point for from_numpy(arg).
// Converts `arg` with torch::utils::tensor_from_numpy, wraps the resulting
// tensor in a Variable created with requires_grad=false, and returns it as a
// Python object via THPVariable_Wrap.
static PyObject * THPVariable_from_numpy(PyObject* module, PyObject* arg)
{
  HANDLE_TH_ERRORS
  auto converted = torch::utils::tensor_from_numpy(arg);
  auto variable = make_variable(std::move(converted), /*requires_grad=*/false);
  return THPVariable_Wrap(std::move(variable));
  END_HANDLE_TH_ERRORS
}
// Python entry point for tensor(*args, **kwargs).
// Delegates construction to torch::utils::new_tensor using the type returned
// by default_type(), and returns the result via THPVariable_Wrap.
static PyObject * THPVariable_tensor(PyObject* self, PyObject* args, PyObject* kwargs)
{
  HANDLE_TH_ERRORS
  const auto& tensor_type = default_type();
  return THPVariable_Wrap(torch::utils::new_tensor(tensor_type, args, kwargs));
  END_HANDLE_TH_ERRORS
}
// generated methods start here
${py_methods}
// Method table installed as tp_methods of THPVariableFunctions.
//
// The hand-written entries come first, followed by the generated entries
// substituted for the py_method_defs placeholder. Several names are aliases
// that share an implementation:
//   "dsmm" and "spmm" -> THPVariable_mm
//   "hsmm"            -> THPVariable_hspmm
//   "saddmm"          -> THPVariable_sspaddmm
// Every entry is registered with METH_STATIC. "from_numpy" uses METH_O (a
// single positional object); the rest accept positional and keyword arguments.
static PyMethodDef torch_functions[] = {
  {"clamp", reinterpret_cast<PyCFunction>(THPVariable_clamp), METH_VARARGS | METH_KEYWORDS | METH_STATIC, nullptr},
  {"dsmm", reinterpret_cast<PyCFunction>(THPVariable_mm), METH_VARARGS | METH_KEYWORDS | METH_STATIC, nullptr},
  {"from_numpy", reinterpret_cast<PyCFunction>(THPVariable_from_numpy), METH_STATIC | METH_O, nullptr},
  {"hsmm", reinterpret_cast<PyCFunction>(THPVariable_hspmm), METH_VARARGS | METH_KEYWORDS | METH_STATIC, nullptr},
  {"saddmm", reinterpret_cast<PyCFunction>(THPVariable_sspaddmm), METH_VARARGS | METH_KEYWORDS | METH_STATIC, nullptr},
  {"spmm", reinterpret_cast<PyCFunction>(THPVariable_mm), METH_VARARGS | METH_KEYWORDS | METH_STATIC, nullptr},
  {"tensor", reinterpret_cast<PyCFunction>(THPVariable_tensor), METH_VARARGS | METH_KEYWORDS | METH_STATIC, nullptr},
  ${py_method_defs}
  {nullptr}  // sentinel terminating the table
};
// Python type object named "torch._C._VariableFunctions".
//
// Apart from the name, the default type flags and the method table, every
// slot is left zero: the type exists to carry the static methods listed in
// torch_functions. It is finalized and added to a module by
// initTorchFunctions().
//
// NOTE(review): the initializer is positional, so the slot order must match
// the PyTypeObject layout of the CPython version being compiled against.
static PyTypeObject THPVariableFunctions = {
PyVarObject_HEAD_INIT(NULL, 0)
"torch._C._VariableFunctions", /* tp_name */
0, /* tp_basicsize */
0, /* tp_itemsize */
0, /* tp_dealloc */
0, /* tp_print */
0, /* tp_getattr */
0, /* tp_setattr */
0, /* tp_reserved */
0, /* tp_repr */
0, /* tp_as_number */
0, /* tp_as_sequence */
0, /* tp_as_mapping */
0, /* tp_hash */
0, /* tp_call */
0, /* tp_str */
0, /* tp_getattro */
0, /* tp_setattro */
0, /* tp_as_buffer */
Py_TPFLAGS_DEFAULT, /* tp_flags */
NULL, /* tp_doc */
0, /* tp_traverse */
0, /* tp_clear */
0, /* tp_richcompare */
0, /* tp_weaklistoffset */
0, /* tp_iter */
0, /* tp_iternext */
// The only populated behavioral slot: the static-method table defined above.
torch_functions, /* tp_methods */
0, /* tp_members */
0, /* tp_getset */
0, /* tp_base */
0, /* tp_dict */
0, /* tp_descr_get */
0, /* tp_descr_set */
0, /* tp_dictoffset */
0, /* tp_init */
0, /* tp_alloc */
0 /* tp_new */
};
// Finalizes the THPVariableFunctions type with PyType_Ready and adds it to
// `module` under the attribute name "_VariableFunctions".
//
// Throws python_error if either PyType_Ready or PyModule_AddObject fails.
void initTorchFunctions(PyObject* module) {
  if (PyType_Ready(&THPVariableFunctions) < 0) {
    throw python_error();
  }
  // PyModule_AddObject steals a reference to the value only on success, so
  // take one reference up front and release it ourselves if the call fails.
  Py_INCREF(&THPVariableFunctions);
  if (PyModule_AddObject(module, "_VariableFunctions", (PyObject*)&THPVariableFunctions) < 0) {
    Py_DECREF(&THPVariableFunctions);
    throw python_error();
  }
}
}} // namespace torch::autograd