Summary:
This PR serves two purposes:
1. Design an abstraction over a serialization scheme for C++ modules, optimizers and tensors in general,
2. Add serialization to the ONNX/PyTorch proto format.
This is currently a rough prototype I coded up today to get quick feedback.
For this, I propose the following serialization interface within the C++ API:
```cpp
namespace torch { namespace serialize {

class Reader {
 public:
  virtual ~Reader() = default;
  virtual void read(const std::string& key, Tensor& tensor, bool is_buffer = false) = 0;
  virtual void finish() {}
};

class Writer {
 public:
  virtual ~Writer() = default;
  virtual void write(const std::string& key, const Tensor& tensor, bool is_buffer = false) = 0;
  virtual void finish() {}
};

}} // namespace torch::serialize
```
There are then subclasses of these two for (1) Cereal and (2) Protobuf (the latter called `DefaultWriter` and `DefaultReader` to hide the implementation details); see `torch/serialize/cereal.h` and `torch/serialize/default.h`. This abstraction, and the subclasses for these two backends, allow us to (a sketch of a custom backend follows the list below):
1. Provide a cereal-less serialization format that we can ship and iterate on going forward,
2. Provide no-friction backwards compatibility with existing C++ API uses, mainly StarCraft.
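
As an illustration of the extension point (not something added by this PR), a custom backend only needs to implement the two virtual methods of the proposed `torch::serialize::Writer`. A minimal sketch, assuming a hypothetical in-memory backend:

```cpp
// Illustrative sketch only (not part of this PR): a custom backend implements
// the two virtuals of the proposed torch::serialize::Writer. The in-memory
// map backend here is hypothetical.
#include <map>
#include <string>
#include <torch/tensor.h>

class InMemoryWriter : public torch::serialize::Writer {
 public:
  void write(
      const std::string& key,
      const torch::Tensor& tensor,
      bool is_buffer) override {
    // Store a copy of the tensor under its fully qualified key.
    storage_.emplace(key, tensor.clone());
  }

  void finish() override {
    // Nothing to flush for an in-memory backend.
  }

 private:
  std::map<std::string, torch::Tensor> storage_;
};
```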
The user-facing API is (conceptually):
```cpp
void torch::save(const Module& module, Writer& writer);
void torch::save(const Optimizer& optimizer, Writer& writer);
void torch::read(Module& module, Reader& reader);
void torch::read(Optimizer& optimizer, Reader& reader);
```
with implementations for both optimizers and modules that write into the `Writer` and read from the `Reader`.
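
A hedged usage sketch of the round trip, assuming `DefaultWriter` and `DefaultReader` can be constructed from a filename (their constructors are not spelled out in this description):

```cpp
// Usage sketch; the filename constructors of DefaultWriter/DefaultReader are
// assumptions, not confirmed by this PR description.
#include <torch/nn/modules/linear.h>

void roundtrip_example() {
  torch::nn::Linear model(/*in=*/4, /*out=*/2);

  torch::serialize::DefaultWriter writer("model.proto");  // assumed ctor
  torch::save(*model, writer);
  writer.finish();

  torch::nn::Linear restored(/*in=*/4, /*out=*/2);
  torch::serialize::DefaultReader reader("model.proto");  // assumed ctor
  torch::read(*restored, reader);
  reader.finish();
}
```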
cc: ebetica, ezyang, zdevito, dzhulgakov
Pull Request resolved: https://github.com/pytorch/pytorch/pull/11619
Differential Revision: D9984664
Pulled By: goldsborough
fbshipit-source-id: e03afaa646221546e7f93bb8dfe3558e384a5847
File view (C++, 68 lines, 1.3 KiB):
```cpp
#pragma once

#include <torch/nn/cloneable.h>
#include <torch/tensor.h>

#include <cstdio>
#include <cstdlib>
#include <string>
#include <utility>

#ifndef WIN32
#include <unistd.h>
#endif

namespace torch {
namespace test {

// Lets you use a container without making a new class,
// for experimental implementations
class SimpleContainer : public nn::Cloneable<SimpleContainer> {
 public:
  void reset() override {}

  template <typename ModuleHolder>
  ModuleHolder add(
      ModuleHolder module_holder,
      std::string name = std::string()) {
    return Module::register_module(std::move(name), module_holder);
  }
};

// Checks whether two tensors share the same underlying data pointer.
inline bool pointer_equal(at::Tensor first, at::Tensor second) {
  return first.data<float>() == second.data<float>();
}

#ifdef WIN32
// On Windows, generate a temporary filename with std::tmpnam.
struct TempFile {
  TempFile() : filename_(std::tmpnam(nullptr)) {}

  const std::string& str() const {
    return filename_;
  }

  std::string filename_;
};
#else
// On POSIX systems, create the file with mkstemp and keep its descriptor open.
struct TempFile {
  TempFile() {
    // http://pubs.opengroup.org/onlinepubs/009695399/functions/mkstemp.html
    char filename[] = "/tmp/fileXXXXXX";
    fd_ = mkstemp(filename);
    AT_CHECK(fd_ != -1, "Error creating tempfile");
    filename_.assign(filename);
  }

  ~TempFile() {
    close(fd_);
  }

  const std::string& str() const {
    return filename_;
  }

  std::string filename_;
  int fd_;
};
#endif

} // namespace test
} // namespace torch
```
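
To show how these helpers are meant to be used in tests, a hypothetical sketch (module names and shapes are illustrative; the `weight` member access follows the C++ `Linear` module):

```cpp
// Hypothetical usage of the test helpers above; not part of the header itself.
#include <memory>
#include <torch/nn/modules/linear.h>

void helper_usage_example() {
  // SimpleContainer: register submodules ad hoc without defining a new class.
  auto container = std::make_shared<torch::test::SimpleContainer>();
  auto linear = container->add(torch::nn::Linear(3, 4), "linear");

  // pointer_equal: true only when two tensors share the same data pointer,
  // so a cloned parameter compares unequal.
  bool shares_storage =
      torch::test::pointer_equal(linear->weight, linear->weight.clone());
  (void)shares_storage;  // expected: false

  // TempFile: a scratch path for serialization round-trip tests.
  torch::test::TempFile file;
  const std::string& path = file.str();
  (void)path;
}
```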