mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Summary:
This PR serves two purposes:
1. Design an abstraction over a serialization scheme for C++ modules, optimizers and tensors in general,
2. Add serialization to the ONNX/PyTorch proto format.
This is currently a rough prototype I coded up today, to get quick feedback.
For this I propose the following serialization interface within the C++ API:
```cpp
namespace torch { namespace serialize {
class Reader {
public:
virtual ~Reader() = default;
virtual void read(const std::string& key, Tensor& tensor, bool is_buffer = false) = 0;
virtual void finish() { }
};
class Writer {
public:
virtual ~Writer() = default;
virtual void write(const std::string& key, const Tensor& tensor, bool is_buffer = false) = 0;
virtual void finish() { }
};
}} // namespace torch::serialize
```
There are then subclasses of these two for (1) Cereal and (2) Protobuf (called the "DefaultWriter" and "DefaultReader" to hide the implementation details). See `torch/serialize/cereal.h` and `torch/serialize/default.h`. This abstraction and subclassing for these two allows us to:
1. Provide a cereal-less serialization format that we can ship and iterate on going forward,
2. Provide no-friction backwards compatibility with existing C++ API uses, mainly StarCraft.
The user-facing API is (conceptually):
```cpp
void torch::save(const Module& module, Writer& writer);
void torch::save(const Optimizer& optimizer, Writer& writer);
void torch::read(Module& module, Reader& reader);
void torch::read(Optimizer& optimizer, Reader& reader);
```
with implementations for both optimizers and modules that write into the `Writer` and read from the `Reader`
ebetica ezyang zdevito dzhulgakov
Pull Request resolved: https://github.com/pytorch/pytorch/pull/11619
Differential Revision: D9984664
Pulled By: goldsborough
fbshipit-source-id: e03afaa646221546e7f93bb8dfe3558e384a5847
56 lines
1.4 KiB
C++
56 lines
1.4 KiB
C++
#include <torch/optim/optimizer.h>
|
|
|
|
#include <torch/nn/cursor.h>
|
|
#include <torch/tensor.h>
|
|
|
|
#include <string>
|
|
#include <utility>
|
|
#include <vector>
|
|
|
|
namespace torch {
|
|
namespace optim {
|
|
namespace detail {
|
|
/// Constructs the optimizer over an explicit list of parameter tensors,
/// taking ownership of the (by-value) vector.
OptimizerBase::OptimizerBase(std::vector<Tensor> parameters) {
  parameters_ = std::move(parameters);
}
|
|
|
|
/// Constructs the optimizer from a cursor over a module's parameters;
/// every tensor the cursor yields is registered with the optimizer.
OptimizerBase::OptimizerBase(const ParameterCursor& cursor) {
  add_parameters(cursor);
}
|
|
|
|
void OptimizerBase::add_parameters(const std::vector<Tensor>& parameters) {
|
|
parameters_.insert(parameters_.end(), parameters.begin(), parameters.end());
|
|
}
|
|
|
|
void OptimizerBase::add_parameters(const ParameterCursor& cursor) {
|
|
std::vector<Tensor> tensors(cursor.size());
|
|
cursor.map(tensors.begin(), [](const Tensor& tensor) { return tensor; });
|
|
add_parameters(tensors);
|
|
}
|
|
|
|
void OptimizerBase::zero_grad() {
|
|
for (auto& parameter : parameters_) {
|
|
if (parameter.grad().defined()) {
|
|
parameter.grad().detach_();
|
|
parameter.grad().zero_();
|
|
}
|
|
}
|
|
}
|
|
|
|
const std::vector<Tensor>& OptimizerBase::parameters() const noexcept {
|
|
return parameters_;
|
|
}
|
|
|
|
std::vector<Tensor>& OptimizerBase::parameters() noexcept {
|
|
return parameters_;
|
|
}
|
|
|
|
size_t OptimizerBase::size() const noexcept {
|
|
return parameters_.size();
|
|
}
|
|
|
|
void OptimizerBase::save(serialize::OutputArchive& archive) const {}
|
|
void OptimizerBase::load(serialize::InputArchive& archive) {}
|
|
} // namespace detail
|
|
} // namespace optim
|
|
} // namespace torch
|