pytorch/torch/csrc/generic/serialization.cpp
Bulent Abali afa5d0823b Fixes big endian arch bugs. (#26383)
Summary:
Serialization.cpp fails on big-endian machines.
This patch fixes the endianness bugs and also makes PyTorch
model files portable across architectures with different byte orders:
a model file generated on x86 can now be read on s390.

The first problem is that serialization.cpp forgets to convert the "size"
value of the storage elements to the native byte order.
As a result, torch.load fails its storage-size check
(see the first stack trace below).

The second problem is that, when it reads the model from storage (doRead),
it decodes the values to little endian, which is the wrong order
on a big-endian machine. The decode should target
THP_nativeByteOrder() instead
(see the model dump below).
```
loaded_model = torch.load( opt.model_file, map_location=torch.device("cpu"))
File "/usr/local/lib/python3.6/dist-packages/torch/serialization.py", line 422, in load
return _load(f, map_location, pickle_module, **pickle_load_args)
File "/usr/local/lib/python3.6/dist-packages/torch/serialization.py", line 616, in _load
deserialized_objects[key]._set_from_file(f, offset, f_should_read_directly)
RuntimeError: storage has wrong size: expected 2305843009213693952 got 32
	(the very long number is actually 32 in the wrong endianness)
```

Model file load on x86 (correct output)
```
>>> import torch
>>> torch.load('400f2k_best.model', map_location=torch.device("cpu"))
{'epoch': 24, 'model_type': 'emb_aec', 'classifier_model': OrderedDict([('model.0.weight', tensor([[ 2.4608e-01, -1.1174e-01, -1.0854e-01,  4.0124e-01, -1.5261e-02,
         -1.2206e-01,  1.3229e-01, -1.2615e-01, -5.2773e-01,  2.6333e-01,
         -3.1462e-03, -1.4902e-01,  9.8545e-02, -1.5789e-01, -2.2625e-01,
         -1.0776e-01, -9.0895e-02, -3.8530e-01,  9.1152e-01, -3.9720e-01,
         -8.5848e-01, -4.7837e-02, -1.5178e-01,  8.5023e-02,  1.5013e-01,
         -9.9294e-02, -2.7422e-01, -4.3986e-01, -4.4297e-01, -3.9570e-01,
```

Model file load on s390x (wrong endianness; notice the exponents)
```
>>> import torch
>>> torch.load( "400f2k_best.model", map_location=torch.device("cpu"))
{'epoch': 24, 'model_type': 'emb_aec', 'classifier_model': OrderedDict([('model.0.weight', tensor([[ 9.2780e+21, -9.7722e-11,  4.1350e+33,  7.782e+34,  4.2056e-31,
          9.0784e+18,  1.1846e-32,  3.3320e-32, -4.8288e-28, -7.2679e+12,
          1.5379e-16, -5.2604e+12, -4.7240e+17,  4.6092e-21, -1.8360e-20,
         -2.7712e-31,  1.4548e-16, -2.5089e-27,  7.9094e-10,  7.1977e+34,
          1.1930e+26,  8.4536e+15,  2.7757e+23, -5.8455e-10, -1.5611e+09,
         -1.1311e-23,  6.6451e+19, -2.0970e+20,  3.4878e-19, -1.0857e-12,
          7.8098e+22,  5.3998e-35],
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/26383

Differential Revision: D17480891

fbshipit-source-id: f40569c7b9c4a1935dceb41f1a2508ce21ea3491
2019-09-19 19:58:02 -07:00

140 lines
4.8 KiB
C++

#ifndef TH_GENERIC_FILE
#define TH_GENERIC_FILE "torch/csrc/generic/serialization.cpp"
#else
#ifdef THC_GENERIC_FILE
#include <c10/cuda/CUDAGuard.h>
#endif
// Serializes `self` into `fd` using the on-disk storage layout:
//   [int64 element count][element 0][element 1]...
// Both the count and the elements are written little-endian regardless of
// the host byte order. CUDA storages are first copied into a host buffer.
//
// self - storage to serialize.
// fd   - destination; a raw file descriptor or a Python file-like object
//        (see the explicit instantiations below).
template <class io>
void THPStorage_(writeFileRaw)(THWStorage *self, io fd)
{
#ifdef THC_GENERIC_FILE
  c10::cuda::CUDAGuard guard(self->device());
#endif
  int64_t size = THWStorage_(size)(LIBRARY_STATE self);

  scalar_t *data;
#ifdef THC_GENERIC_FILE
  // Device memory cannot be handed to doWrite; stage it in host memory.
  std::unique_ptr<char[]> cpu_data(new char[size * sizeof(scalar_t)]);
  data = (scalar_t*)cpu_data.get();
  THCudaCheck(cudaMemcpy(data, THWStorage_(data)(LIBRARY_STATE self), size * sizeof(scalar_t), cudaMemcpyDeviceToHost));
#else
  data = THWStorage_(data)(LIBRARY_STATE self);
#endif

  const bool host_is_little_endian =
      THP_nativeByteOrder() == THPByteOrder::THP_LITTLE_ENDIAN;

  // Header: the element count, always stored little-endian.
  if (host_is_little_endian) {
    doWrite(fd, &size, sizeof(int64_t));
  } else {
    int64_t le_size;
    THP_encodeInt64Buffer((uint8_t*)&le_size, (const int64_t *)&size,
                          THPByteOrder::THP_LITTLE_ENDIAN, 1);
    doWrite(fd, &le_size, sizeof(int64_t));
  }

  // Payload fast path: a little-endian host already matches the on-disk
  // layout, and single-byte elements have no byte order to fix up.
  if (host_is_little_endian || sizeof(scalar_t) == 1) {
    doWrite(fd, data, sizeof(scalar_t) * size);
    return;
  }

  // Big-endian host: convert at most 5000 elements at a time into a scratch
  // buffer so the extra memory stays bounded.
  const int64_t chunk = std::min(size, (int64_t)5000);
  std::unique_ptr<uint8_t[]> scratch(new uint8_t[chunk * sizeof(scalar_t)]);
  for (int64_t offset = 0; offset < size; offset += chunk) {
    const size_t count = std::min(size - offset, chunk);
    uint8_t *out = scratch.get();
    switch (sizeof(scalar_t)) {
      case 2:
        THP_encodeInt16Buffer(out, (const int16_t*)data + offset,
                              THPByteOrder::THP_LITTLE_ENDIAN, count);
        break;
      case 4:
        THP_encodeInt32Buffer(out, (const int32_t*)data + offset,
                              THPByteOrder::THP_LITTLE_ENDIAN, count);
        break;
      case 8:
        THP_encodeInt64Buffer(out, (const int64_t*)data + offset,
                              THPByteOrder::THP_LITTLE_ENDIAN, count);
        break;
      default:
        // NOTE(review): other element widths are not converted; the scratch
        // buffer is written without being filled.
        break;
    }
    doWrite(fd, out, count * sizeof(scalar_t));
  }
}
template void THPStorage_(writeFileRaw<int>)(THWStorage *self, int fd);
template void THPStorage_(writeFileRaw<PyObject*>)(THWStorage *self, PyObject* fd);
// Deserializes a storage in the layout produced by writeFileRaw:
//   [int64 element count][element 0][element 1]...
// Both the count and the elements are little-endian on disk; they are
// converted to the host byte order while reading.
//
// file     - source; a raw file descriptor or a Python file-like object
//            (see the explicit instantiations below).
// _storage - optional destination. When null, a new storage with the recorded
//            element count is allocated. Otherwise its size must equal the
//            recorded count. A mismatch is reported through THPUtils_assert
//            before `_storage` is adopted.
//            NOTE(review): THPUtils_assert presumably sets a Python error and
//            returns nullptr from this function; confirm against its definition.
// Returns the filled storage, released from the local THWStoragePtr.
template <class io>
THWStorage * THPStorage_(readFileRaw)(io file, THWStorage *_storage)
{
#ifdef THC_GENERIC_FILE
  // A device is only pinned when reading into an existing storage.
  c10::cuda::OptionalCUDAGuard guard;
  if (_storage != nullptr) {
    guard.set_device(_storage->device());
  }
#endif
  scalar_t *data;
  int64_t size;
  doRead(file, &size, sizeof(int64_t));
  if (THP_nativeByteOrder() == THPByteOrder::THP_BIG_ENDIAN) {
    int64_t nsize; // convert little endian storage to big endian cpu
    nsize = size;
    // NOTE(review): the native order is passed as `order` here and in the
    // element decode below. Correctness depends on the byte_order helpers'
    // convention for that argument; confirm before changing it.
    THP_decodeInt64Buffer(&size, (const uint8_t*)&nsize, THP_nativeByteOrder(), 1);
  }
  THWStoragePtr storage;
  if (_storage == nullptr) {
    storage = THWStorage_(newWithSize)(LIBRARY_STATE size);
  } else {
    int64_t existing_size = THWStorage_(size)(LIBRARY_STATE _storage);
    // int64_t is `long long` on LLP64 targets (e.g. Windows), where %ld would
    // not match the vararg type; format both values through long long.
    THPUtils_assert(existing_size == size,
        "storage has wrong size: expected %lld got %lld",
        (long long)size, (long long)existing_size);
    storage = _storage;
  }
#ifndef THC_GENERIC_FILE
  data = THWStorage_(data)(LIBRARY_STATE storage);
#else
  // Stage the data in host memory; it is copied to the device at the end.
  std::unique_ptr<char[]> cpu_data(new char[size * sizeof(scalar_t)]);
  data = (scalar_t*)cpu_data.get();
#endif
  // fast track for bytes and little endian
  if (sizeof(scalar_t) == 1 || THP_nativeByteOrder() == THPByteOrder::THP_LITTLE_ENDIAN) {
    // `size` equals the storage size here: the storage was either allocated
    // with `size` elements or checked against it above.
    doRead(file, data, sizeof(scalar_t) * size);
  } else {
    // Big-endian host: read at most 5000 little-endian elements at a time
    // into le_buffer and decode them into `data`.
    int64_t buffer_size = std::min(size, (int64_t)5000);
    std::unique_ptr<uint8_t[]> le_buffer(new uint8_t[buffer_size * sizeof(scalar_t)]);
    for (int64_t i = 0; i < size; i += buffer_size) {
      size_t to_convert = std::min(size - i, buffer_size);
      doRead(file, le_buffer.get(), sizeof(scalar_t) * to_convert);
      if (sizeof(scalar_t) == 2) {
        THP_decodeInt16Buffer((int16_t*)data + i,
            le_buffer.get(),
            THP_nativeByteOrder(),
            to_convert);
      } else if (sizeof(scalar_t) == 4) {
        THP_decodeInt32Buffer((int32_t*)data + i,
            le_buffer.get(),
            THP_nativeByteOrder(),
            to_convert);
      } else if (sizeof(scalar_t) == 8) {
        THP_decodeInt64Buffer((int64_t*)data + i,
            le_buffer.get(),
            THP_nativeByteOrder(),
            to_convert);
      }
      // NOTE(review): other element widths are not decoded on big-endian
      // hosts; `data` is left unfilled for them.
    }
  }
#ifdef THC_GENERIC_FILE
  THCudaCheck(cudaMemcpy(THWStorage_(data)(LIBRARY_STATE storage), data, size * sizeof(scalar_t), cudaMemcpyHostToDevice));
#endif
  return storage.release();
}
template THWStorage* THPStorage_(readFileRaw<int>)(int fd, THWStorage* storage);
template THWStorage* THPStorage_(readFileRaw<PyObject*>)(PyObject* fd, THWStorage* storage);
#endif