pytorch/torch/csrc/Device.cpp
Peter Goldsborough 372d1d6735
Create ATen tensors via TensorOptions (#7869)
* Created TensorOptions

Storing the type in TensorOptions to solve the Variable problem

Created convenience creation functions for TensorOptions and added tests

Converted zeros to TensorOptions

Converted rand to TensorOptions

Fix codegen for TensorOptions and multiple arguments

Put TensorOptions convenience functions into torch namespace too

All factory functions except *_like support TensorOptions

Integrated with recent JIT changes

Support *_like functions

Fix in place modification

Some cleanups and fixes

Support sparse_coo_tensor

Fix bug in Type.cpp

Fix .empty calls in C++ API

Fix bug in Type.cpp

Trying to fix device placement

Make AutoGPU CPU compatible

Remove some auto_gpu.h uses

Fixing some headers

Fix some remaining CUDA/AutoGPU issues

Fix some AutoGPU uses

Fixes to dispatch_tensor_conversion

Reset version of new variables to zero

Implemented parsing device strings

Random fixes to tests

Self review cleanups

flake8

Undo changes to variable.{h,cpp} because they fail on gcc7.2

Add [cuda] tag to tensor_options_cuda.cpp

Move AutoGPU::set_index_from into .cpp file to work around Windows build/linker issues

Fix linker error in AutoGPU.cpp

Fix bad merge conflict in native_functions.yaml

Fixed caffe2/contrib/aten

Fix new window functions added to TensorFactories.cpp

* Removed torch::TensorOptions

Added code to generate wrapper functions for factory methods

Add implicit constructor from Backend to TensorOptions

Remove Var() from C++ API and use torch:: functions

Use torch:: functions more subtly in C++ API

Make AutoGPU::set_device more exception safe

Check status directly in DynamicCUDAHooksInterface

Rename AutoGPU to DeviceGuard

Removed set_requires_grad from python_variables.h and warn appropriately in Variable::set_requires_grad

remove python_default_init: self.type()

Add back original factory functions, but with deprecation warnings

Disable DeviceGuard for a couple functions in ATen

Remove print statement

Fix DeviceGuard construction from undefined tensor

Fixing CUDA device compiler issues

Moved as many methods as possible into header files

Don't generate Python functions for deprecated factories

Remove merge conflict artefact

Fix tensor_options_cuda.cpp

Fix set_requires_grad not being checked

Fix tensor_new.h

TEMPORARILY put some methods in .cpp files to see if it solves issues on windows and mac

Fix bug in DeviceGuard.h

Missing includes

TEMPORARILY moving a few more methods into .cpp to see if it fixes windows

Fixing linker errors

* Fix up SummaryOps to use new factories

Undo device agnostic behavior of DeviceGuard

Use -1 instead of optional for default device index

Also move DeviceGuard methods into header

Fixes around device index after optional -> int32_t switch

Fix use of DeviceGuard in new_with_tensor_copy

Fix tensor_options.cpp

* Fix Type::copy()

* Remove test_non_float_params from ONNX tests

* Set requires_grad=False in ONNX tests that use ints

* Put layout/dtype/device on Tensor

* Post merge fixes

* Change behavior of DeviceGuard to match AutoGPU

* Fix C++ API integration tests

* Fix flip functions
2018-06-16 00:40:35 -07:00

221 lines
6.9 KiB
C++

#include "torch/csrc/Device.h"
#include "torch/csrc/Exceptions.h"
#include "torch/csrc/utils/object_ptr.h"
#include "torch/csrc/utils/python_arg_parser.h"
#include "torch/csrc/utils/python_strings.h"
#include "torch/csrc/utils/pybind.h"
#include <ATen/Device.h>
#include <ATen/Error.h>
#include <cstring>
#include <limits>
#include <structmember.h>
#include <sstream>
// Creates a new Python `torch.device` object wrapping a copy of `device`.
// Returns a new reference; throws python_error if allocation fails (the
// Python error indicator is left set by tp_alloc).
PyObject *THPDevice_New(const at::Device& device)
{
  PyTypeObject* device_type = &THPDeviceType;
  THPObjectPtr obj{device_type->tp_alloc(device_type, 0)};
  if (!obj) {
    throw python_error();
  }
  reinterpret_cast<THPDevice*>(obj.get())->device = device;
  // Hand ownership of the new reference to the caller.
  return obj.release();
}
// tp_repr for torch.device: renders `device(type='<type>')` or
// `device(type='<type>', index=<n>)` when the device carries an index.
// Wrapped in HANDLE_TH_ERRORS (like THPDevice_type) so that a C++ exception
// raised while formatting is converted into a Python exception instead of
// unwinding through CPython's C frames.
PyObject *THPDevice_repr(THPDevice *self)
{
  HANDLE_TH_ERRORS
  std::ostringstream oss;
  oss << "device(type='" << self->device.type() << "'";
  if (self->device.has_index()) {
    oss << ", index=" << self->device.index();
  }
  oss << ")";
  return THPUtils_packString(oss.str().c_str());
  END_HANDLE_TH_ERRORS
}
// tp_str for torch.device: delegates to at::Device's stream operator.
// Wrapped in HANDLE_TH_ERRORS (like THPDevice_type) so that a C++ exception
// raised while formatting is converted into a Python exception instead of
// unwinding through CPython's C frames.
PyObject *THPDevice_str(THPDevice *self)
{
  HANDLE_TH_ERRORS
  std::ostringstream oss;
  oss << self->device;
  return THPUtils_packString(oss.str().c_str());
  END_HANDLE_TH_ERRORS
}
// tp_new for torch.device. Accepts either
//   torch.device(device)              -- an existing device or a device string
//   torch.device(type, index=None)    -- a type string plus an explicit index
// Returns a new reference. Errors are raised as Python exceptions via
// HANDLE_TH_ERRORS (std::runtime_error for an indexed type string in the
// second form, and the AT_CHECK errors for an out-of-range index).
PyObject *THPDevice_pynew(PyTypeObject *type, PyObject *args, PyObject *kwargs)
{
  HANDLE_TH_ERRORS
  static torch::PythonArgParser parser({
    "Device(Device device)",
    "Device(std::string type, int64_t? index=-1)"
  });
  torch::ParsedArgs<2> parsed_args;
  auto r = parser.parse(args, kwargs, parsed_args);
  if (r.idx == 0) {
    auto device = r.device(0);
    return THPDevice_New(device);
  } else if (r.idx == 1) {
    auto as_device = r.device(0);  // this works, because device can take strings
    auto device_type = r.string(0);
    if (as_device.has_index()) {
      throw std::runtime_error("type (string) must not include an index because index "
                               "was passed explicitly: " + device_type);
    }
    int32_t device_index = -1;
    if (!r.isNone(1)) {
      // Validate in 64 bits *before* narrowing to at::Device's int32_t index:
      // narrowing first would wrap e.g. 2**32 to 0 (silently picking device 0)
      // or 2**31 to a negative value (misreported as "negative").
      const int64_t requested_index = r.toInt64(1);
      // -1 is allowed in ATen/C++, to mean the default device, but not in
      // Python.
      AT_CHECK(requested_index >= 0, "Device index must not be negative");
      AT_CHECK(requested_index <= std::numeric_limits<int32_t>::max(),
               "Device index is too large");
      device_index = static_cast<int32_t>(requested_index);
    }
    at::Device device(as_device.type(), device_index);
    return THPDevice_New(device);
  }
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}
// Getter for `torch.device.type`: returns the device type rendered through
// its stream operator, as a Python string.
// (The original trailing `Py_RETURN_NONE` after the return was unreachable
// and has been removed.)
PyObject *THPDevice_type(THPDevice *self)
{
  HANDLE_TH_ERRORS
  std::ostringstream oss;
  oss << self->device.type();
  return THPUtils_packString(oss.str().c_str());
  END_HANDLE_TH_ERRORS
}
// Getter for `torch.device.index`: the device index as a Python int, or
// None when the device was created without an index.
PyObject *THPDevice_index(THPDevice *self)
{
  HANDLE_TH_ERRORS
  if (!self->device.has_index()) {
    Py_RETURN_NONE;
  }
  return THPUtils_packInt64(self->device.index());
  END_HANDLE_TH_ERRORS
}
// tp_richcompare for torch.device. Only == and != are supported between two
// devices; ordering comparisons raise TypeError. If either operand is not a
// torch.device, NotImplemented is returned so Python can try the reflected
// operation.
PyObject *THPDevice_rc(PyObject *a, PyObject *b, int op) {
  HANDLE_TH_ERRORS
  if (!(THPDevice_Check(a) && THPDevice_Check(b))) {
    // Py_RETURN_NOTIMPLEMENTED not in python 2.
    Py_INCREF(Py_NotImplemented);
    return Py_NotImplemented;
  }
  const at::Device& lhs = reinterpret_cast<THPDevice*>(a)->device;
  const at::Device& rhs = reinterpret_cast<THPDevice*>(b)->device;
  switch (op) {
    case Py_EQ:
    case Py_NE: {
      const bool equal = (lhs == rhs);
      const bool answer = (op == Py_EQ) ? equal : !equal;
      if (answer) {
        Py_RETURN_TRUE;
      }
      Py_RETURN_FALSE;
    }
    case Py_LT:
    case Py_LE:
    case Py_GT:
    case Py_GE:
      // Devices have no ordering.
      throw torch::TypeError("comparison not implemented");
    default:
      throw torch::TypeError("unexpected comparison op");
  }
  END_HANDLE_TH_ERRORS
}
// __reduce__ for torch.device (pickle support). Returns the tuple
// (torch.device, (type,)) or (torch.device, (type, index)), so unpickling
// reconstructs the device by calling the constructor with those arguments.
PyObject *THPDevice_reduce(THPDevice *self)
{
  HANDLE_TH_ERRORS
  auto result = THPObjectPtr{PyTuple_New(2)};
  if (!result) throw python_error();

  py::object torch_module = py::module::import("torch");
  py::object device_ctor = torch_module.attr("device");
  // PyTuple_SET_ITEM steals the reference released from the py::object.
  PyTuple_SET_ITEM(result.get(), 0, device_ctor.release().ptr());

  std::ostringstream type_stream;
  type_stream << self->device.type();
  const std::string type_name = type_stream.str();

  THPObjectPtr ctor_args{
      self->device.has_index()
          ? Py_BuildValue("(si)", type_name.c_str(), self->device.index())
          : Py_BuildValue("(s)", type_name.c_str())};
  if (!ctor_args) throw python_error();
  PyTuple_SET_ITEM(result.get(), 1, ctor_args.release());
  return result.release();
  END_HANDLE_TH_ERRORS
}
typedef PyObject *(*getter)(PyObject *, void *);
static struct PyGetSetDef THPDevice_properties[] = {
{"type", (getter)THPDevice_type, nullptr, nullptr, nullptr},
{"index", (getter)THPDevice_index, nullptr, nullptr, nullptr},
{nullptr}
};
// Methods on torch.device: only __reduce__ (pickle support, implemented by
// THPDevice_reduce, which takes no arguments).
static PyMethodDef THPDevice_methods[] = {
  {"__reduce__", reinterpret_cast<PyCFunction>(THPDevice_reduce), METH_NOARGS, nullptr},
  {nullptr, nullptr, 0, nullptr}  /* Sentinel */
};
// Static Python type object backing `torch.device`. Slots are initialised
// positionally in PyTypeObject field order; the trailing /* tp_* */ comments
// label each position and must stay aligned with their entries.
// Wired-up behaviour: repr/str (THPDevice_repr / THPDevice_str), rich
// comparison (THPDevice_rc), the `type`/`index` getters
// (THPDevice_properties), __reduce__ (THPDevice_methods) and construction
// (THPDevice_pynew). THPDevice_init registers it on a module as `device`.
// Flags are only Py_TPFLAGS_DEFAULT (no Py_TPFLAGS_BASETYPE), so the type
// cannot be subclassed from Python.
PyTypeObject THPDeviceType = {
PyVarObject_HEAD_INIT(nullptr, 0)
"torch.device", /* tp_name */
sizeof(THPDevice), /* tp_basicsize */
0, /* tp_itemsize */
// NOTE(review): no tp_dealloc is provided; this appears to rely on the slot
// being inherited during PyType_Ready, which means at::Device's destructor
// is never run. Assumes at::Device is trivially destructible -- confirm.
0, /* tp_dealloc */
0, /* tp_print */
0, /* tp_getattr */
0, /* tp_setattr */
0, /* tp_reserved */
(reprfunc)THPDevice_repr, /* tp_repr */
0, /* tp_as_number */
0, /* tp_as_sequence */
0, /* tp_as_mapping */
// NOTE(review): tp_richcompare is set below but tp_hash is 0. Under CPython's
// slot-inheritance rules, on Python 3 this likely leaves torch.device
// unhashable (__hash__ = None). Confirm whether that is intended.
0, /* tp_hash */
0, /* tp_call */
(reprfunc)THPDevice_str, /* tp_str */
0, /* tp_getattro */
0, /* tp_setattro */
0, /* tp_as_buffer */
Py_TPFLAGS_DEFAULT, /* tp_flags */
nullptr, /* tp_doc */
0, /* tp_traverse */
0, /* tp_clear */
(richcmpfunc)THPDevice_rc, /* tp_richcompare */
0, /* tp_weaklistoffset */
0, /* tp_iter */
0, /* tp_iternext */
THPDevice_methods, /* tp_methods */
0, /* tp_members */
THPDevice_properties, /* tp_getset */
0, /* tp_base */
0, /* tp_dict */
0, /* tp_descr_get */
0, /* tp_descr_set */
0, /* tp_dictoffset */
0, /* tp_init */
0, /* tp_alloc */
THPDevice_pynew, /* tp_new */
};
// Finalises THPDeviceType and registers it on `module` under the name
// `device`. Throws python_error (with the Python error indicator set) if
// either step fails.
void THPDevice_init(PyObject *module)
{
  if (PyType_Ready(&THPDeviceType) < 0) {
    throw python_error();
  }
  // PyModule_AddObject steals a reference only on success, so take one for
  // the module up front and give it back ourselves if registration fails
  // (the original code leaked it on that path).
  Py_INCREF(&THPDeviceType);
  if (PyModule_AddObject(module, "device", reinterpret_cast<PyObject *>(&THPDeviceType)) != 0) {
    Py_DECREF(&THPDeviceType);
    throw python_error();
  }
}