mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/62030 Remove dtype tracking from Python Storage interface, remove all the different `<type>Storage` classes except for `ByteStorage`, and update serialization accordingly, while maintaining as much FC/BC as possible Fixes https://github.com/pytorch/pytorch/issues/47442 * **THE SERIALIZATION FORMAT IS FULLY FC/BC.** We worked very hard to make sure this is the case. We will probably want to break FC at some point to make the serialization structure of tensors make more sense, but not today. * There is now only a single torch.ByteStorage class. Methods like `Tensor.set_` no longer check that the dtype of storage is appropriate. * As we no longer know what dtype of a storage is, we've **removed** the size method from Storage, replacing it with nbytes. This is to help catch otherwise silent errors where you confuse number of elements with number of bytes. * `Storage._new_shared` takes a `nbytes` kwarg and will reject previous positional only calls. `Storage._new_with_file` and `_set_from_file` require explicit element size arguments. * It's no longer possible to convert storages to different types using the float/double/etc methods. Instead, do the conversion using a tensor. * It's no longer possible to allocate a typed storage directly using FloatStorage/DoubleStorage/etc constructors. Instead, construct a tensor and extract its storage. The classes still exist but they are used purely for unpickling. * The preexisting serialization format stores dtype with storage, and in fact this dtype is used to determine the dtype of the tensor overall. To accommodate this case, we introduce a new TypedStorage concept that exists only during unpickling time which is used to temporarily store the dtype so we can construct a tensor. **If you overrode the handling of pickling/unpickling, you MUST add handling for TypedStorage** or your serialization code will degrade to standard file-based serialization. 
Original pull request: https://github.com/pytorch/pytorch/pull/59671 Reviewed By: soulitzer, ngimel Differential Revision: D29466819 Pulled By: ezyang fbshipit-source-id: 4a14e5d3c2b08e06e558683d97f7378a3180b00e
413 lines
15 KiB
C++
413 lines
15 KiB
C++
#include <ATen/ATen.h>
|
|
#include <ATen/MapAllocator.h>
|
|
#include <torch/csrc/utils/pycfunction_helpers.h>
|
|
#include <torch/csrc/utils/python_numbers.h>
|
|
|
|
#ifdef USE_CUDA
|
|
#include <cuda_runtime.h>
|
|
#endif
|
|
|
|
#ifdef _MSC_VER
|
|
#define LSEEK _lseeki64
|
|
#else
|
|
#define LSEEK lseek
|
|
#endif
|
|
|
|
// storage.nbytes(): total size of the underlying storage in bytes,
// returned as a Python int.
static PyObject * THPStorage_(nbytes)(PyObject *_self, PyObject *noargs)
{
  HANDLE_TH_ERRORS
  auto storage = reinterpret_cast<THPStorage*>(_self);
  const auto num_bytes = storage->cdata->nbytes();
  return THPUtils_packUInt64(num_bytes);
  END_HANDLE_TH_ERRORS
}
|
|
|
|
// storage.data_ptr(): address of the storage's data buffer, exposed to
// Python as an int.
static PyObject * THPStorage_(dataPtr)(PyObject *_self, PyObject *noargs)
{
  HANDLE_TH_ERRORS
  auto storage = reinterpret_cast<THPStorage*>(_self);
  void *data = THWStorage_(data)(LIBRARY_STATE storage->cdata);
  return PyLong_FromVoidPtr(data);
  END_HANDLE_TH_ERRORS
}
|
|
|
|
// storage.copy_(src, ...): in-place copy, delegated to the shared copy
// dispatch table registered for this storage type.
static PyObject * THPStorage_(copy_)(PyObject *self, PyObject *args, PyObject *kwargs)
{
  HANDLE_TH_ERRORS
  PyObject *result =
      THPStorageCopyMethod(THWStorage_(copy_functions), self, args, kwargs);
  return result;
  END_HANDLE_TH_ERRORS
}
|
|
|
|
// storage.is_pinned(): True iff the data pointer refers to pinned
// (page-locked) host memory; always False in CPU-only builds.
static PyObject * THPStorage_(isPinned)(PyObject *_self, PyObject *noargs)
{
  HANDLE_TH_ERRORS
#if defined(USE_CUDA)
  auto storage = reinterpret_cast<THPStorage*>(_self);
  const bool pinned = at::globalContext().isPinnedPtr(
      THWStorage_(data)(LIBRARY_STATE storage->cdata));
  return PyBool_FromLong(pinned);
#else
  Py_RETURN_FALSE;
#endif
  END_HANDLE_TH_ERRORS
}
|
|
|
|
// storage.element_size(): byte width of one element of this storage
// type; does not depend on the instance.
static PyObject * THPStorage_(elementSize)(PyObject *_self, PyObject *noargs)
{
  HANDLE_TH_ERRORS
  const auto element_size = THWStorage_(elementSize)(LIBRARY_STATE_NOARGS);
  return THPUtils_packInt64(element_size);
  END_HANDLE_TH_ERRORS
}
|
|
|
|
// storage.new(): allocate a fresh, empty storage of the same type and
// return it wrapped in a new Python object.
static PyObject * THPStorage_(new)(PyObject *_self, PyObject *noargs)
{
  HANDLE_TH_ERRORS
  // NOTE(review): `self` is not referenced below; the method behaves
  // like a static constructor.
  auto self = (THPStorage*)_self;
  // RAII holder frees the new storage if an exception escapes before
  // ownership is handed off.
  THWStoragePtr new_storage(THWStorage_(new)(LIBRARY_STATE_NOARGS));
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  PyObject *_ret = THPStorage_(New)(new_storage);
  // Drop the RAII guard: ownership now lives in the Python wrapper.
  // NOTE(review): release() runs even when _ret is nullptr — presumably
  // THPStorage_(New) consumes the pointer on failure too; confirm.
  new_storage.release();
  return _ret;
  END_HANDLE_TH_ERRORS
}
|
|
|
|
// storage.resize_(n): resize the storage to hold `n` elements and
// return self.  Raises if the argument is not an int or is negative.
static PyObject * THPStorage_(resize_)(PyObject *_self, PyObject *number_arg)
{
  HANDLE_TH_ERRORS
  auto self = (THPStorage*)_self;
  THPUtils_assert(THPUtils_checkLong(number_arg), "resize_ expects an int, "
      "but got %s", THPUtils_typename(number_arg));
  int64_t newsize = THPUtils_unpackLong(number_arg);
  // BUGFIX: reject negative sizes up front — otherwise
  // newsize * sizeof(scalar_t) would be passed to resizeBytes as a
  // nonsensical byte count.
  THPUtils_assert(newsize >= 0, "resize_ expects a non-negative int, "
      "but got %lld", (long long)newsize);
  THWStorage_(resizeBytes)(
      LIBRARY_STATE self->cdata, newsize * sizeof(scalar_t));
  Py_INCREF(self);
  return (PyObject*)self;
  END_HANDLE_TH_ERRORS
}
|
|
|
|
// storage.fill_(value): set every element of the storage to `value`
// and return self.  The argument must be convertible to this storage's
// scalar type.
static PyObject * THPStorage_(fill_)(PyObject *_self, PyObject *number_arg)
{
  HANDLE_TH_ERRORS
  auto storage = (THPStorage*)_self;
  THPUtils_assert(THPUtils_(checkReal)(number_arg), "fill_ expects %s, "
      "but got %s", THPUtils_typeTraits<scalar_t>::python_type_str,
      THPUtils_typename(number_arg));
  const auto fill_value = THPUtils_(unpackReal)(number_arg);
  THWStorage_(fill)(LIBRARY_STATE storage->cdata, fill_value);
  Py_INCREF(storage);
  return (PyObject*)storage;
  END_HANDLE_TH_ERRORS
}
|
|
|
|
#if !defined(THC_GENERIC_FILE)
|
|
static PyObject * THPStorage_(fromBuffer)(PyObject *_unused, PyObject *args, PyObject *keywds)
|
|
{
|
|
HANDLE_TH_ERRORS
|
|
PyObject *obj = nullptr;
|
|
const char* byte_order_str = nullptr;
|
|
Py_ssize_t count = -1, offset = 0;
|
|
PyObject* dtype_obj = nullptr;
|
|
c10::ScalarType scalar_type = at::kByte;
|
|
Py_buffer buffer = {};
|
|
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays,clang-diagnostic-writable-strings)
|
|
static char *kwlist[] = {"buffer", "byte_order", "count", "offset", "dtype", nullptr};
|
|
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
|
|
const char* argtypes;
|
|
argtypes = "O|snnO";
|
|
|
|
if (!PyArg_ParseTupleAndKeywords(args, keywds, argtypes, kwlist,
|
|
&obj, &byte_order_str, &count, &offset, &dtype_obj)) {
|
|
return nullptr;
|
|
}
|
|
TORCH_CHECK(dtype_obj != nullptr, "argument 'dtype' cannot be None");
|
|
TORCH_CHECK(THPDtype_Check(dtype_obj), "argument 'dtype' must be of type torch.dtype");
|
|
auto dtype = reinterpret_cast<THPDtype*>(dtype_obj);
|
|
scalar_type = dtype->scalar_type;
|
|
|
|
TORCH_CHECK(
|
|
(scalar_type == at::kByte) || (scalar_type == at::kChar) || (byte_order_str != nullptr),
|
|
"function missing required argument 'byte_order' (pos 2)");
|
|
size_t element_size = c10::elementSize(scalar_type);
|
|
|
|
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
|
|
torch::utils::THPByteOrder byte_order;
|
|
if (scalar_type != at::kByte && scalar_type != at::kChar) {
|
|
if (strcmp(byte_order_str, "native") == 0) {
|
|
byte_order = torch::utils::THP_nativeByteOrder();
|
|
} else if (strcmp(byte_order_str, "big") == 0) {
|
|
byte_order = torch::utils::THP_BIG_ENDIAN;
|
|
} else if (strcmp(byte_order_str, "little") == 0) {
|
|
byte_order = torch::utils::THP_LITTLE_ENDIAN;
|
|
} else {
|
|
PyErr_Format(PyExc_ValueError,
|
|
"invalid byte_order '%s' (expected 'big', 'little', or 'native')",
|
|
byte_order_str);
|
|
return nullptr;
|
|
}
|
|
}
|
|
|
|
if (PyObject_GetBuffer(obj, &buffer, PyBUF_SIMPLE) < 0)
|
|
return nullptr;
|
|
|
|
if (offset < 0 || offset > buffer.len) {
|
|
PyErr_SetString(PyExc_ValueError, fmt::format(
|
|
"offset must be non-negative and no greater than buffer length ({}) , but got {}",
|
|
offset, buffer.len));
|
|
PyBuffer_Release(&buffer);
|
|
return nullptr;
|
|
}
|
|
|
|
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
|
|
size_t size_bytes;
|
|
if (count < 0) {
|
|
if ((buffer.len - offset) % element_size != 0) {
|
|
PyErr_SetString(PyExc_ValueError, fmt::format(
|
|
"buffer size ({}) must be a multiple of element size ({})",
|
|
buffer.len, element_size));
|
|
PyBuffer_Release(&buffer);
|
|
return nullptr;
|
|
}
|
|
size_bytes = buffer.len - offset;
|
|
count = size_bytes / element_size;
|
|
} else {
|
|
size_bytes = count * element_size;
|
|
}
|
|
|
|
if (offset + (count * (Py_ssize_t)element_size) > buffer.len) {
|
|
PyErr_SetString(PyExc_ValueError, fmt::format(
|
|
"buffer has only {} elements after offset {}, but specified a size of {}",
|
|
buffer.len - offset, offset, count));
|
|
PyBuffer_Release(&buffer);
|
|
return nullptr;
|
|
}
|
|
|
|
uint8_t* src = (uint8_t*) buffer.buf;
|
|
THWStorage* storage = THWStorage_(newWithSize)(size_bytes);
|
|
|
|
if (scalar_type == at::kByte || scalar_type == at::kChar) {
|
|
memcpy(THWStorage_(data)(storage), src + offset, count);
|
|
} else if (scalar_type == at::kBool) {
|
|
// Because of ASAN checks, that are failing in the THStorage.cpp whenever
|
|
// we are trying to get a value which is not 0 or 1, we have to manually
|
|
// convert original values to boolean ones.
|
|
torch::utils::THP_decodeBoolBuffer(
|
|
reinterpret_cast<bool*>(THWStorage_(data)(storage)),
|
|
src + offset, byte_order, count);
|
|
} else if (scalar_type == at::kShort) {
|
|
torch::utils::THP_decodeInt16Buffer(
|
|
reinterpret_cast<int16_t*>(THWStorage_(data)(storage)),
|
|
src + offset, byte_order, count);
|
|
} else if (scalar_type == at::kInt) {
|
|
torch::utils::THP_decodeInt32Buffer(
|
|
reinterpret_cast<int32_t*>(THWStorage_(data)(storage)),
|
|
src + offset, byte_order, count);
|
|
} else if (scalar_type == at::kLong) {
|
|
torch::utils::THP_decodeInt64Buffer(
|
|
reinterpret_cast<int64_t*>(THWStorage_(data)(storage)),
|
|
src + offset, byte_order, count);
|
|
} else if (scalar_type == at::kHalf) {
|
|
torch::utils::THP_decodeHalfBuffer(
|
|
reinterpret_cast<c10::Half*>(THWStorage_(data)(storage)),
|
|
src + offset, byte_order, count);
|
|
} else if (scalar_type == at::kBFloat16) {
|
|
torch::utils::THP_decodeBFloat16Buffer(
|
|
reinterpret_cast<c10::BFloat16*>(THWStorage_(data)(storage)),
|
|
src + offset, byte_order, count);
|
|
} else if (scalar_type == at::kFloat) {
|
|
torch::utils::THP_decodeFloatBuffer(
|
|
reinterpret_cast<float*>(THWStorage_(data)(storage)),
|
|
src + offset, byte_order, count);
|
|
} else if (scalar_type == at::kDouble) {
|
|
torch::utils::THP_decodeDoubleBuffer(
|
|
reinterpret_cast<double*>(THWStorage_(data)(storage)),
|
|
src + offset, byte_order, count);
|
|
} else if (scalar_type == at::kComplexFloat) {
|
|
torch::utils::THP_decodeComplexFloatBuffer(
|
|
reinterpret_cast<c10::complex<float>*>(THWStorage_(data)(storage)),
|
|
src + offset, byte_order, count);
|
|
} else if (scalar_type == at::kComplexDouble) {
|
|
torch::utils::THP_decodeComplexDoubleBuffer(
|
|
reinterpret_cast<c10::complex<double>*>(THWStorage_(data)(storage)),
|
|
src + offset, byte_order, count);
|
|
} else {
|
|
TORCH_CHECK(false, "Unknown type: ", scalar_type);
|
|
}
|
|
|
|
PyBuffer_Release(&buffer);
|
|
return (PyObject*)THPStorage_(New)(storage);
|
|
END_HANDLE_TH_ERRORS
|
|
}
|
|
#endif
|
|
|
|
// storage.from_file(filename, shared, nbytes): memory-map `filename`
// into a new storage.  A truthy `shared` maps the file with
// ALLOCATOR_MAPPED_SHARED so changes are written back to disk.
static PyObject * THPStorage_(fromFile)(PyObject *_unused, PyObject *args, PyObject *keywds)
{
  HANDLE_TH_ERRORS
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  const char *filename;
  Py_ssize_t nbytes = 0;
  int shared = 0;
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays,clang-diagnostic-writable-strings)
  static char *kwlist[] = {"filename", "shared", "nbytes", nullptr};
  if (!PyArg_ParseTupleAndKeywords(args, keywds, "s|in", kwlist,
        &filename, &shared, &nbytes)) {
    return nullptr;
  }
  const int flags = shared ? at::ALLOCATOR_MAPPED_SHARED : 0;
  THWStorage *storage =
      THWStorage_(newWithMapping)(LIBRARY_STATE filename, nbytes, flags);
  return (PyObject*)THPStorage_(New)(storage);
  END_HANDLE_TH_ERRORS
}
|
|
|
|
// storage._write_file(file, is_real_file, save_size, element_size):
// serialize the storage either to a real file descriptor or to a
// file-like Python object.  Returns None.
PyObject * THPStorage_(writeFile)(PyObject *_self, PyObject *args)
{
  HANDLE_TH_ERRORS
  auto self = (THPStorage*)_self;
  PyObject *file = PyTuple_GetItem(args, 0);
  bool is_real_file = PyTuple_GetItem(args, 1) == Py_True;
  bool save_size = PyTuple_GetItem(args, 2) == Py_True;
  PyObject *element_size_obj = PyTuple_GET_ITEM(args, 3);

  THPUtils_assert(element_size_obj != Py_None,
                  "_write_file: need to specify element size");
  uint64_t element_size = THPUtils_unpackUInt64(element_size_obj);

  if (is_real_file) {
    // Real file: write through its OS-level descriptor.
    int fd = PyObject_AsFileDescriptor(file);
    THPUtils_assert(fd != -1, "_write_file couldn't retrieve a file descriptor "
                    "from given object");
    THPStorage_(writeFileRaw)(self->cdata, fd, save_size, element_size);
  } else {
    // File-like object: write via its Python-level API.
    THPStorage_(writeFileRaw<PyObject*>)(self->cdata, file, save_size, element_size);
  }
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}
|
|
|
|
// storage._new_with_file(fd_obj, element_size): static constructor that
// reads a serialized storage from a real file descriptor.
PyObject * THPStorage_(newWithFile)(PyObject *_unused, PyObject *args)
{
  HANDLE_TH_ERRORS
  TORCH_CHECK(PyTuple_Size(args) == 2,
      "_new_with_file takes exactly two arguments");
  // Fetch the tuple item once and reuse it (previously `fd_obj` was an
  // unused local and PyTuple_GetItem(args, 0) was called a second time).
  PyObject *fd_obj = PyTuple_GetItem(args, 0);
  int fd = PyObject_AsFileDescriptor(fd_obj);
  THPUtils_assert(fd != -1, "_new_with_file couldn't retrieve a file "
      "descriptor from given object");
  PyObject *element_size_obj = PyTuple_GET_ITEM(args, 1);
  THPUtils_assert(element_size_obj != Py_None,
      "_new_with_file: need to specify element size");
  uint64_t element_size = THPUtils_unpackUInt64(element_size_obj);

  THWStorage *storage = THPStorage_(readFileRaw<int>)(fd, nullptr, element_size);
  if (storage == nullptr)
    return nullptr;
  PyObject *result = THPStorage_(New)(storage);
  return result;
  END_HANDLE_TH_ERRORS
}
|
|
|
|
// storage._set_from_file(file, offset, is_real_file, element_size):
// overwrite this storage's contents with serialized data read from
// `file`.  For fd-backed files the descriptor position is restored and
// the Python-level file handle is advanced via seek(); for file-like
// objects `offset` must be None.  Returns self.
static PyObject *THPStorage_(setFromFile)(PyObject *_self, PyObject *args)
{
  HANDLE_TH_ERRORS
  auto self = (THPStorage*)_self;
  PyObject *file = PyTuple_GET_ITEM(args, 0);
  PyObject *offset = PyTuple_GET_ITEM(args, 1);
  bool is_real_file = PyTuple_GET_ITEM(args, 2) == Py_True;

  PyObject *element_size_obj = PyTuple_GET_ITEM(args, 3);

  THPUtils_assert(element_size_obj != Py_None,
      "_set_from_file: need to specify element size");
  uint64_t element_size = THPUtils_unpackUInt64(element_size_obj);

  if (!is_real_file) {
    // offset can be implemented with a call to the Python object's seek()
    // but it is currently unnecessary to support this.
    THPUtils_assert(offset == Py_None,
        "_set_from_file: offset is NYI for filelike objects");
    THWStorage *storage = THPStorage_(readFileRaw<PyObject*>)(file, self->cdata, element_size);
    if (storage == nullptr) {
      return nullptr;
    }
    Py_INCREF(self);
    return (PyObject *) self;
  }

  // file is backed by a fd
  const int fd = PyObject_AsFileDescriptor(file);
  // BUGFIX: validate the descriptor BEFORE using it — the check used to
  // come after fd had already been passed to LSEEK twice.
  THPUtils_assert(fd != -1, "_set_from_file couldn't retrieve a file "
      "descriptor from given object");
  const auto fd_original_pos = LSEEK(fd, 0, SEEK_CUR);
  if (offset != Py_None) {
    LSEEK(fd, THPUtils_unpackLong(offset), SEEK_SET);
  }
  THWStorage *storage = THPStorage_(readFileRaw<int>)(fd, self->cdata, element_size);
  if (storage == nullptr)
    return nullptr;
  Py_INCREF(self);

  // the file descriptor is returned to original position and
  // the file handle at python call-site needs updating to the
  // advanced position
  const auto fd_current_pos = LSEEK(fd, 0, SEEK_CUR);
  LSEEK(fd, fd_original_pos, SEEK_SET);
  const auto seek_return = PyObject_CallMethod(file, "seek", "Li", (long long)fd_current_pos, 0);
  if (seek_return == nullptr) {
    return nullptr;
  }
  Py_DECREF(seek_return);

  return (PyObject *) self;
  END_HANDLE_TH_ERRORS
}
|
|
|
|
#ifdef THC_GENERIC_FILE
// storage.get_device(): index of the device this THC storage lives on
// (only compiled for THC generic files).
PyObject * THPStorage_(getDevice)(PyObject *_self, PyObject *noargs)
{
  HANDLE_TH_ERRORS
  auto storage = reinterpret_cast<THPStorage*>(_self);
  const auto device = THCStorage_(getDevice)(LIBRARY_STATE storage->cdata);
  return THPUtils_packInt32(device);
  END_HANDLE_TH_ERRORS
}
#endif
|
|
|
|
// storage._set_cdata(addr): repoint this Python wrapper at a different
// underlying THWStorage, given as an integer address, and return self.
PyObject * THPStorage_(_setCdata)(PyObject *_self, PyObject *new_cdata)
{
  HANDLE_TH_ERRORS
  auto self = (THPStorage*)_self;
  THPUtils_assert(THPUtils_checkLong(new_cdata), "given an invalid argument to "
      "_set_cdata - expected an int or long, but got %s",
      THPUtils_typename(new_cdata));
  auto *new_ptr = static_cast<THWStorage*>(PyLong_AsVoidPtr(new_cdata));
  // Retain the incoming storage before freeing the old one, so the
  // ordering stays safe even if both are the same pointer.
  THWStorage_(retain)(LIBRARY_STATE new_ptr);
  THWStorage_(free)(LIBRARY_STATE self->cdata);
  self->cdata = new_ptr;
  Py_INCREF(self);
  return (PyObject*)self;
  END_HANDLE_TH_ERRORS
}
|
|
|
|
// Method table exported on the Python Storage type.  METH_STATIC
// entries are constructors that are called on the class rather than an
// instance; from_buffer is compiled out for THC generic files and
// get_device is compiled in only for them.
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables)
static PyMethodDef THPStorage_(methods)[] = {
  // Instance methods.
  {"copy_", castPyCFunctionWithKeywords(THPStorage_(copy_)),
    METH_VARARGS | METH_KEYWORDS, nullptr},
  {"element_size", THPStorage_(elementSize), METH_NOARGS, nullptr},
  {"fill_", THPStorage_(fill_), METH_O, nullptr},
  {"new", THPStorage_(new), METH_NOARGS, nullptr},
  {"resize_", THPStorage_(resize_), METH_O, nullptr},
  {"nbytes", THPStorage_(nbytes), METH_NOARGS, nullptr},
  {"data_ptr", THPStorage_(dataPtr), METH_NOARGS, nullptr},
  {"is_pinned", THPStorage_(isPinned), METH_NOARGS, nullptr},
  // Serialization helpers (underscore-prefixed: internal API).
  {"_write_file", THPStorage_(writeFile), METH_VARARGS, nullptr},
  {"_new_with_file", THPStorage_(newWithFile), METH_VARARGS | METH_STATIC, nullptr},
  {"_set_from_file", THPStorage_(setFromFile), METH_VARARGS, nullptr},
#if !defined(THC_GENERIC_FILE)
  {"from_buffer", castPyCFunctionWithKeywords(THPStorage_(fromBuffer)),
    METH_VARARGS | METH_KEYWORDS | METH_STATIC, nullptr},
#endif
  {"from_file", castPyCFunctionWithKeywords(THPStorage_(fromFile)),
    METH_VARARGS | METH_KEYWORDS | METH_STATIC, nullptr},
#ifdef THC_GENERIC_FILE
  {"get_device", THPStorage_(getDevice), METH_NOARGS, nullptr},
#endif
  {"_set_cdata", THPStorage_(_setCdata), METH_O, nullptr},
  {nullptr}  // sentinel
};
|