Protobuf serialization (#11619)

Summary:
This PR serves two purposes:

1. Design an abstraction over serialization schemes for C++ modules, optimizers, and tensors in general,
2. Add serialization to the ONNX/PyTorch proto format.

This is currently a rough prototype I coded up today to get quick feedback.

For this I propose the following serialization interface within the C++ API:

```cpp
namespace torch { namespace serialize {
class Reader {
 public:
  virtual ~Reader() = default;
  virtual void read(const std::string& key, Tensor& tensor, bool is_buffer = false) = 0;
  virtual void finish() { }
};

class Writer {
 public:
  virtual ~Writer() = default;
  virtual void write(const std::string& key, const Tensor& tensor, bool is_buffer = false) = 0;
  virtual void finish() { }
};
}} // namespace torch::serialize
```

There are then subclasses of these two for (1) Cereal and (2) Protobuf (called the "DefaultWriter" and "DefaultReader" to hide the implementation details); see `torch/serialize/cereal.h` and `torch/serialize/default.h`. This abstraction and these two subclasses allow us to:

1. Provide a cereal-free serialization format that we can ship and iterate on going forward,
2. Provide frictionless backwards compatibility with existing C++ API users, mainly StarCraft.
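
For illustration, a minimal third backend against the interface above might look like the sketch below. The `InMemoryWriter` name and its map-based storage are hypothetical, purely to show the subclassing contract; the real backends in this PR are the cereal pair and the Protobuf-based `DefaultWriter`/`DefaultReader`:

```cpp
#include <torch/serialize.h>  // assumed location of Reader/Writer
#include <torch/tensor.h>

#include <map>
#include <string>
#include <utility>

// Hypothetical in-memory backend: write() clones each tensor under its key
// so later in-place updates to the source don't leak into the "archive".
class InMemoryWriter : public torch::serialize::Writer {
 public:
  void write(const std::string& key, const torch::Tensor& tensor, bool is_buffer) override {
    tensors_[key] = std::make_pair(tensor.clone(), is_buffer);
  }

 private:
  std::map<std::string, std::pair<torch::Tensor, bool>> tensors_;
};
```

A matching `InMemoryReader` would implement `read()` by looking keys up in the same map; `finish()` gives stream-oriented backends a hook to flush or validate at the end.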

The user-facing API is (conceptually):

```cpp
void torch::save(const Module& module, Writer& writer);
void torch::save(const Optimizer& optimizer, Writer& writer);
void torch::load(Module& module, Reader& reader);
void torch::load(Optimizer& optimizer, Reader& reader);
```

with implementations for both optimizers and modules that write into the `Writer` and read from the `Reader`.
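
With the file-based convenience overloads this PR adds (see `torch/serialize.h` and `test/cpp/api/serialize.cpp` below), a round trip then looks roughly like this; the filenames are arbitrary:

```cpp
#include <torch/nn/modules/linear.h>
#include <torch/serialize.h>
#include <torch/tensor.h>

int main() {
  // Save a module through an OutputArchive and restore it into a fresh one.
  auto model = torch::nn::Linear(5, 2);
  torch::save(model, "model.pt");

  auto model2 = torch::nn::Linear(5, 2);
  torch::load(model2, "model.pt");

  // Tensors round-trip too; the tensor overload of load() returns by value.
  auto x = torch::randn({5, 5});
  torch::save(x, "x.pt");
  torch::Tensor y = torch::load("x.pt");
  return 0;
}
```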

ebetica ezyang zdevito dzhulgakov
Pull Request resolved: https://github.com/pytorch/pytorch/pull/11619

Differential Revision: D9984664

Pulled By: goldsborough

fbshipit-source-id: e03afaa646221546e7f93bb8dfe3558e384a5847
This commit is contained in:
Peter Goldsborough 2018-09-20 20:36:22 -07:00 committed by Facebook Github Bot
parent 30521a37ad
commit d712a71741
36 changed files with 889 additions and 822 deletions

View File

@@ -124,7 +124,6 @@ cmake_dependent_option(
cmake_dependent_option(
USE_GLOO_IBVERBS "Use Gloo IB verbs for distributed. Only available if USE_GLOO is on." OFF
"USE_GLOO" OFF)
option(TORCH_USE_CEREAL "Build the C++ API with Cereal for serialization support" OFF)
# Used when building Caffe2 through setup.py
option(BUILDING_WITH_TORCH_LIBS "Tell cmake if Caffe2 is being built alongside torch libs" OFF)

View File

@@ -125,7 +125,6 @@ function (caffe2_print_configuration_summary)
message(STATUS " USE_GLOO : ${USE_GLOO}")
message(STATUS " USE_GLOO_IBVERBS : ${USE_GLOO_IBVERBS}")
endif()
message(STATUS " TORCH_USE_CEREAL : ${TORCH_USE_CEREAL}")
message(STATUS " Public Dependencies : ${Caffe2_PUBLIC_DEPENDENCY_LIBS}")
message(STATUS " Private Dependencies : ${Caffe2_DEPENDENCY_LIBS}")

View File

@@ -3,6 +3,8 @@
#define CATCH_CONFIG_PREFIX_ALL
#include <catch.hpp>
// CATCH_REQUIRE_THROWS is not defined identically to REQUIRE_THROWS and causes warning;
// define our own version that doesn't warn.
#define _CATCH_REQUIRE_THROWS( ... ) INTERNAL_CATCH_THROWS( "CATCH_REQUIRE_THROWS", Catch::ResultDisposition::Normal, __VA_ARGS__ )
// CATCH_REQUIRE_THROWS is not defined identically to REQUIRE_THROWS and causes
// warning; define our own version that doesn't warn.
#define _CATCH_REQUIRE_THROWS(...) \
INTERNAL_CATCH_THROWS( \
"CATCH_REQUIRE_THROWS", Catch::ResultDisposition::Normal, __VA_ARGS__)

View File

@@ -1,339 +0,0 @@
#include "catch_utils.hpp"
#include <torch/nn/modules/functional.h>
#include <torch/nn/modules/linear.h>
#include <torch/nn/modules/sequential.h>
#include <torch/optim/optimizer.h>
#include <torch/optim/sgd.h>
#include <torch/serialization.h>
#include <torch/tensor.h>
#include <torch/utils.h>
#include <test/cpp/api/util.h>
#include <cereal/archives/portable_binary.hpp>
#include <memory>
#include <sstream>
#include <string>
#include <vector>
using namespace torch::nn;
namespace {
Sequential xor_model() {
return Sequential(
Linear(2, 8),
Functional(at::sigmoid),
Linear(8, 1),
Functional(at::sigmoid));
}
} // namespace
CATCH_TEST_CASE("serialization") {
torch::manual_seed(0);
CATCH_SECTION("undefined") {
auto x = torch::Tensor();
CATCH_REQUIRE(!x.defined());
auto y = torch::randn({5});
std::stringstream ss;
torch::save(ss, &x);
torch::load(ss, &y);
CATCH_REQUIRE(!y.defined());
}
CATCH_SECTION("cputypes") {
for (int i = 0; i < static_cast<int>(torch::Dtype::NumOptions); i++) {
if (i == static_cast<int>(torch::Dtype::Half)) {
// XXX can't serialize half tensors at the moment since contiguous() is
// not implemented for this type;
continue;
} else if (at::isComplexType(static_cast<torch::Dtype>(i))) {
// Not supported yet
continue;
} else if (i == static_cast<int>(torch::Dtype::Undefined)) {
// We can't construct a tensor for this type. This is tested in
// serialization/undefined anyway.
continue;
}
auto x = torch::ones(
{5, 5}, static_cast<torch::Dtype>(i));
auto y = torch::empty({});
std::stringstream ss;
torch::save(ss, &x);
torch::load(ss, &y);
CATCH_REQUIRE(y.defined());
CATCH_REQUIRE(x.sizes().vec() == y.sizes().vec());
if (torch::isIntegralType(static_cast<torch::Dtype>(i))) {
CATCH_REQUIRE(x.equal(y));
} else {
CATCH_REQUIRE(x.allclose(y));
}
}
}
CATCH_SECTION("binary") {
auto x = torch::randn({5, 5});
auto y = torch::Tensor();
std::stringstream ss;
{
cereal::BinaryOutputArchive archive(ss);
archive(x);
}
{
cereal::BinaryInputArchive archive(ss);
archive(y);
}
CATCH_REQUIRE(y.defined());
CATCH_REQUIRE(x.sizes().vec() == y.sizes().vec());
CATCH_REQUIRE(x.allclose(y));
}
CATCH_SECTION("portable_binary") {
auto x = torch::randn({5, 5});
auto y = torch::Tensor();
std::stringstream ss;
{
cereal::PortableBinaryOutputArchive archive(ss);
archive(x);
}
{
cereal::PortableBinaryInputArchive archive(ss);
archive(y);
}
CATCH_REQUIRE(y.defined());
CATCH_REQUIRE(x.sizes().vec() == y.sizes().vec());
CATCH_REQUIRE(x.allclose(y));
}
CATCH_SECTION("resized") {
auto x = torch::randn({11, 5});
x.resize_({5, 5});
auto y = torch::Tensor();
std::stringstream ss;
{
cereal::BinaryOutputArchive archive(ss);
archive(x);
}
{
cereal::BinaryInputArchive archive(ss);
archive(y);
}
CATCH_REQUIRE(y.defined());
CATCH_REQUIRE(x.sizes().vec() == y.sizes().vec());
CATCH_REQUIRE(x.allclose(y));
}
CATCH_SECTION("sliced") {
auto x = torch::randn({11, 5});
x = x.slice(0, 1, 3);
auto y = torch::Tensor();
std::stringstream ss;
{
cereal::BinaryOutputArchive archive(ss);
archive(x);
}
{
cereal::BinaryInputArchive archive(ss);
archive(y);
}
CATCH_REQUIRE(y.defined());
CATCH_REQUIRE(x.sizes().vec() == y.sizes().vec());
CATCH_REQUIRE(x.allclose(y));
}
CATCH_SECTION("noncontig") {
auto x = torch::randn({11, 5});
x = x.slice(1, 1, 4);
auto y = torch::Tensor();
std::stringstream ss;
{
cereal::BinaryOutputArchive archive(ss);
archive(x);
}
{
cereal::BinaryInputArchive archive(ss);
archive(y);
}
CATCH_REQUIRE(y.defined());
CATCH_REQUIRE(x.sizes().vec() == y.sizes().vec());
CATCH_REQUIRE(x.allclose(y));
}
CATCH_SECTION("xor") {
// We better be able to save and load a XOR model!
auto getLoss = [](Sequential model, uint32_t batch_size) {
auto inputs = torch::empty({batch_size, 2});
auto labels = torch::empty({batch_size});
for (size_t i = 0; i < batch_size; i++) {
inputs[i] = torch::randint(2, {2}, torch::kInt64);
labels[i] = inputs[i][0].toCLong() ^ inputs[i][1].toCLong();
}
auto x = model->forward<torch::Tensor>(inputs);
return torch::binary_cross_entropy(x, labels);
};
auto model = xor_model();
auto model2 = xor_model();
auto model3 = xor_model();
auto optimizer = torch::optim::SGD(
model->parameters(),
torch::optim::SGDOptions(1e-1)
.momentum(0.9)
.nesterov(true)
.weight_decay(1e-6));
float running_loss = 1;
int epoch = 0;
while (running_loss > 0.1) {
torch::Tensor loss = getLoss(model, 4);
optimizer.zero_grad();
loss.backward();
optimizer.step();
running_loss = running_loss * 0.99 + loss.sum().toCFloat() * 0.01;
CATCH_REQUIRE(epoch < 3000);
epoch++;
}
std::stringstream ss;
torch::save(ss, model);
torch::load(ss, model2);
auto loss = getLoss(model2, 100);
CATCH_REQUIRE(loss.toCFloat() < 0.1);
}
CATCH_SECTION("optim") {
auto model1 = Linear(5, 2);
auto model2 = Linear(5, 2);
auto model3 = Linear(5, 2);
// Models 1, 2, 3 will have the same params
std::stringstream ss;
torch::save(ss, model1.get());
torch::load(ss, model2.get());
ss.seekg(0, std::ios::beg);
torch::load(ss, model3.get());
auto param1 = model1->parameters();
auto param2 = model2->parameters();
auto param3 = model3->parameters();
for (const auto& p : param1) {
CATCH_REQUIRE(param1[p.key].allclose(param2[p.key]));
CATCH_REQUIRE(param2[p.key].allclose(param3[p.key]));
}
// Make some optimizers with momentum (and thus state)
auto optim1 = torch::optim::SGD(
model1->parameters(), torch::optim::SGDOptions(1e-1).momentum(0.9));
auto optim2 = torch::optim::SGD(
model2->parameters(), torch::optim::SGDOptions(1e-1).momentum(0.9));
auto optim2_2 = torch::optim::SGD(
model2->parameters(), torch::optim::SGDOptions(1e-1).momentum(0.9));
auto optim3 = torch::optim::SGD(
model3->parameters(), torch::optim::SGDOptions(1e-1).momentum(0.9));
auto optim3_2 = torch::optim::SGD(
model3->parameters(), torch::optim::SGDOptions(1e-1).momentum(0.9));
auto x = torch::ones({10, 5});
auto step = [&x](torch::optim::Optimizer& optimizer, Linear model) {
optimizer.zero_grad();
auto y = model->forward(x).sum();
y.backward();
optimizer.step();
};
// Do 2 steps of model1
step(optim1, model1);
step(optim1, model1);
// Do 2 steps of model 2 without saving the optimizer
step(optim2, model2);
step(optim2_2, model2);
// Do 2 steps of model 3 while saving the optimizer
step(optim3, model3);
ss.clear();
torch::save(ss, &optim3);
torch::load(ss, &optim3_2);
step(optim3_2, model3);
param1 = model1->parameters();
param2 = model2->parameters();
param3 = model3->parameters();
for (const auto& p : param1) {
const auto& name = p.key;
// Model 1 and 3 should be the same
CATCH_REQUIRE(param1[name].norm().toCFloat() == param3[name].norm().toCFloat());
CATCH_REQUIRE(param1[name].norm().toCFloat() != param2[name].norm().toCFloat());
}
}
}
CATCH_TEST_CASE("serialization_cuda", "[cuda]") {
torch::manual_seed(0);
// We better be able to save and load a XOR model!
auto getLoss = [](Sequential model, uint32_t batch_size) {
auto inputs = torch::empty({batch_size, 2});
auto labels = torch::empty({batch_size});
for (size_t i = 0; i < batch_size; i++) {
inputs[i] = torch::randint(2, {2}, torch::kInt64);
labels[i] = inputs[i][0].toCLong() ^ inputs[i][1].toCLong();
}
auto x = model->forward<torch::Tensor>(inputs);
return torch::binary_cross_entropy(x, labels);
};
auto model = xor_model();
auto model2 = xor_model();
auto model3 = xor_model();
auto optimizer = torch::optim::SGD(
model->parameters(),
torch::optim::SGDOptions(1e-1).momentum(0.9).nesterov(true).weight_decay(
1e-6));
float running_loss = 1;
int epoch = 0;
while (running_loss > 0.1) {
torch::Tensor loss = getLoss(model, 4);
optimizer.zero_grad();
loss.backward();
optimizer.step();
running_loss = running_loss * 0.99 + loss.sum().toCFloat() * 0.01;
CATCH_REQUIRE(epoch < 3000);
epoch++;
}
std::stringstream ss;
torch::save(ss, model);
torch::load(ss, model2);
auto loss = getLoss(model2, 100);
CATCH_REQUIRE(loss.toCFloat() < 0.1);
model2->to(torch::kCUDA);
ss.clear();
torch::save(ss, model2);
torch::load(ss, model3);
loss = getLoss(model3, 100);
CATCH_REQUIRE(loss.toCFloat() < 0.1);
}

test/cpp/api/serialize.cpp Normal file
View File

@@ -0,0 +1,246 @@
#include <test/cpp/api/catch_utils.hpp>
#include <torch/nn/modules/functional.h>
#include <torch/nn/modules/linear.h>
#include <torch/nn/modules/sequential.h>
#include <torch/optim/optimizer.h>
#include <torch/optim/sgd.h>
#include <torch/serialize.h>
#include <torch/tensor.h>
#include <torch/utils.h>
#include <test/cpp/api/util.h>
#include <cstdio>
#include <memory>
#include <sstream>
#include <string>
#include <vector>
using namespace torch::nn;
using namespace torch::serialize;
namespace {
Sequential xor_model() {
return Sequential(
Linear(2, 8),
Functional(at::sigmoid),
Linear(8, 1),
Functional(at::sigmoid));
}
torch::Tensor save_and_load(torch::Tensor input) {
torch::test::TempFile tempfile;
torch::save(input, tempfile.str());
return torch::load(tempfile.str());
}
} // namespace
CATCH_TEST_CASE("Serialize/Default/Basic") {
torch::manual_seed(0);
auto x = torch::randn({5, 5});
auto y = save_and_load(x);
CATCH_REQUIRE(y.defined());
CATCH_REQUIRE(x.sizes().vec() == y.sizes().vec());
CATCH_REQUIRE(x.allclose(y));
}
CATCH_TEST_CASE("Serialize/Default/Resized") {
torch::manual_seed(0);
auto x = torch::randn({11, 5});
x.resize_({5, 5});
auto y = save_and_load(x);
CATCH_REQUIRE(y.defined());
CATCH_REQUIRE(x.sizes().vec() == y.sizes().vec());
CATCH_REQUIRE(x.allclose(y));
}
CATCH_TEST_CASE("Serialize/Default/Sliced") {
torch::manual_seed(0);
auto x = torch::randn({11, 5});
x = x.slice(0, 1, 5);
auto y = save_and_load(x);
CATCH_REQUIRE(y.defined());
CATCH_REQUIRE(x.sizes().vec() == y.sizes().vec());
CATCH_REQUIRE(x.allclose(y));
}
CATCH_TEST_CASE("Serialize/Default/NonContiguous") {
torch::manual_seed(0);
auto x = torch::randn({11, 5});
x = x.slice(1, 1, 4);
auto y = save_and_load(x);
CATCH_REQUIRE(y.defined());
CATCH_REQUIRE(x.sizes().vec() == y.sizes().vec());
CATCH_REQUIRE(x.allclose(y));
}
CATCH_TEST_CASE("Serialize/Default/XOR") {
// We better be able to save and load an XOR model!
auto getLoss = [](Sequential model, uint32_t batch_size) {
auto inputs = torch::empty({batch_size, 2});
auto labels = torch::empty({batch_size});
for (size_t i = 0; i < batch_size; i++) {
inputs[i] = torch::randint(2, {2}, torch::kInt64);
labels[i] = inputs[i][0].toCLong() ^ inputs[i][1].toCLong();
}
auto x = model->forward<torch::Tensor>(inputs);
return torch::binary_cross_entropy(x, labels);
};
auto model = xor_model();
auto model2 = xor_model();
auto model3 = xor_model();
auto optimizer = torch::optim::SGD(
model->parameters(),
torch::optim::SGDOptions(1e-1).momentum(0.9).nesterov(true).weight_decay(
1e-6));
float running_loss = 1;
int epoch = 0;
while (running_loss > 0.1) {
torch::Tensor loss = getLoss(model, 4);
optimizer.zero_grad();
loss.backward();
optimizer.step();
running_loss = running_loss * 0.99 + loss.sum().toCFloat() * 0.01;
CATCH_REQUIRE(epoch < 3000);
epoch++;
}
torch::test::TempFile tempfile;
torch::save(model, tempfile.str());
torch::load(model2, tempfile.str());
auto loss = getLoss(model2, 100);
CATCH_REQUIRE(loss.toCFloat() < 0.1);
}
CATCH_TEST_CASE("Serialize/Default/Optim") {
auto model1 = Linear(5, 2);
auto model2 = Linear(5, 2);
auto model3 = Linear(5, 2);
// Models 1, 2, 3 will have the same parameters.
torch::test::TempFile model_tempfile;
torch::save(model1, model_tempfile.str());
torch::load(model2, model_tempfile.str());
torch::load(model3, model_tempfile.str());
auto param1 = model1->parameters();
auto param2 = model2->parameters();
auto param3 = model3->parameters();
for (const auto& p : param1) {
CATCH_REQUIRE(param1[p.key].allclose(param2[p.key]));
CATCH_REQUIRE(param2[p.key].allclose(param3[p.key]));
}
// Make some optimizers with momentum (and thus state)
auto optim1 = torch::optim::SGD(
model1->parameters(), torch::optim::SGDOptions(1e-1).momentum(0.9));
auto optim2 = torch::optim::SGD(
model2->parameters(), torch::optim::SGDOptions(1e-1).momentum(0.9));
auto optim2_2 = torch::optim::SGD(
model2->parameters(), torch::optim::SGDOptions(1e-1).momentum(0.9));
auto optim3 = torch::optim::SGD(
model3->parameters(), torch::optim::SGDOptions(1e-1).momentum(0.9));
auto optim3_2 = torch::optim::SGD(
model3->parameters(), torch::optim::SGDOptions(1e-1).momentum(0.9));
auto x = torch::ones({10, 5});
auto step = [&x](torch::optim::Optimizer& optimizer, Linear model) {
optimizer.zero_grad();
auto y = model->forward(x).sum();
y.backward();
optimizer.step();
};
// Do 2 steps of model1
step(optim1, model1);
step(optim1, model1);
// Do 2 steps of model 2 without saving the optimizer
step(optim2, model2);
step(optim2_2, model2);
// Do 2 steps of model 3 while saving the optimizer
step(optim3, model3);
torch::test::TempFile optim_tempfile;
torch::save(optim3, optim_tempfile.str());
torch::load(optim3_2, optim_tempfile.str());
step(optim3_2, model3);
param1 = model1->parameters();
param2 = model2->parameters();
param3 = model3->parameters();
for (const auto& p : param1) {
const auto& name = p.key;
// Model 1 and 3 should be the same
CATCH_REQUIRE(
param1[name].norm().toCFloat() == param3[name].norm().toCFloat());
CATCH_REQUIRE(
param1[name].norm().toCFloat() != param2[name].norm().toCFloat());
}
}
CATCH_TEST_CASE("Serialize/Default/CUDA", "[cuda]") {
torch::manual_seed(0);
// We better be able to save and load a XOR model!
auto getLoss = [](Sequential model, uint32_t batch_size) {
auto inputs = torch::empty({batch_size, 2});
auto labels = torch::empty({batch_size});
for (size_t i = 0; i < batch_size; i++) {
inputs[i] = torch::randint(2, {2}, torch::kInt64);
labels[i] = inputs[i][0].toCLong() ^ inputs[i][1].toCLong();
}
auto x = model->forward<torch::Tensor>(inputs);
return torch::binary_cross_entropy(x, labels);
};
auto model = xor_model();
auto model2 = xor_model();
auto model3 = xor_model();
auto optimizer = torch::optim::SGD(
model->parameters(),
torch::optim::SGDOptions(1e-1).momentum(0.9).nesterov(true).weight_decay(
1e-6));
float running_loss = 1;
int epoch = 0;
while (running_loss > 0.1) {
torch::Tensor loss = getLoss(model, 4);
optimizer.zero_grad();
loss.backward();
optimizer.step();
running_loss = running_loss * 0.99 + loss.sum().toCFloat() * 0.01;
CATCH_REQUIRE(epoch < 3000);
epoch++;
}
torch::test::TempFile tempfile;
torch::save(model, tempfile.str());
torch::load(model2, tempfile.str());
auto loss = getLoss(model2, 100);
CATCH_REQUIRE(loss.toCFloat() < 0.1);
model2->to(torch::kCUDA);
torch::test::TempFile tempfile2;
torch::save(model2, tempfile2.str());
torch::load(model3, tempfile2.str());
loss = getLoss(model3, 100);
CATCH_REQUIRE(loss.toCFloat() < 0.1);
}

View File

@@ -1,10 +1,17 @@
#pragma once
#include <torch/nn/cloneable.h>
#include <torch/tensor.h>
#include <cstdio>
#include <cstdlib>
#include <string>
#include <utility>
#ifndef WIN32
#include <unistd.h>
#endif
namespace torch {
namespace test {
@@ -25,5 +32,36 @@ class SimpleContainer : public nn::Cloneable<SimpleContainer> {
inline bool pointer_equal(at::Tensor first, at::Tensor second) {
return first.data<float>() == second.data<float>();
}
#ifdef WIN32
struct TempFile {
TempFile() : filename_(std::tmpnam(nullptr)) {}
const std::string& str() const {
return filename_;
}
std::string filename_;
};
#else
struct TempFile {
TempFile() {
// http://pubs.opengroup.org/onlinepubs/009695399/functions/mkstemp.html
char filename[] = "/tmp/fileXXXXXX";
fd_ = mkstemp(filename);
AT_CHECK(fd_ != -1, "Error creating tempfile");
filename_.assign(filename);
}
~TempFile() {
close(fd_);
}
const std::string& str() const {
return filename_;
}
std::string filename_;
int fd_;
};
#endif
} // namespace test
} // namespace torch

View File

@@ -9,7 +9,6 @@ from setup_helpers.cuda import USE_CUDA
if __name__ == '__main__':
# Placeholder for future interface. For now just gives a nice -h.
parser = argparse.ArgumentParser(description='Build libtorch')
parser.add_argument('--use-cereal', action='store_true')
options = parser.parse_args()
os.environ['BUILD_TORCH'] = 'ON'
@@ -25,8 +24,6 @@ if __name__ == '__main__':
command.append('--use-cuda')
if os.environ.get('USE_CUDA_STATIC_LINK', False):
command.append('--cuda-static-link')
if options.use_cereal:
command.append('--use-cereal')
command.append('caffe2')
sys.stdout.flush()

View File

@@ -22,7 +22,6 @@ USE_NNPACK=0
USE_MKLDNN=0
USE_GLOO_IBVERBS=0
CAFFE2_STATIC_LINK_CUDA=0
TORCH_USE_CEREAL=0
RERUN_CMAKE=1
while [[ $# -gt 0 ]]; do
case "$1" in
@@ -47,9 +46,6 @@ while [[ $# -gt 0 ]]; do
--cuda-static-link)
CAFFE2_STATIC_LINK_CUDA=1
;;
--use-cereal)
TORCH_USE_CEREAL=1
;;
*)
break
;;
@@ -194,7 +190,6 @@ function build() {
-DTHCUNN_SO_VERSION=1 \
-DTHD_SO_VERSION=1 \
-DUSE_CUDA=$USE_CUDA \
-DTORCH_USE_CEREAL=$TORCH_USE_CEREAL \
-DBUILD_EXAMPLES=OFF \
-DBUILD_TEST=$BUILD_TEST \
-DNO_NNPACK=$((1-$USE_NNPACK)) \

View File

@@ -220,8 +220,8 @@ CONFIGURE_FILE(
if (NOT NO_API AND NOT USE_ROCM)
list(APPEND TORCH_SRCS
${TORCH_SRC_DIR}/csrc/api/src/utils.cpp
${TORCH_SRC_DIR}/csrc/api/src/cuda.cpp
${TORCH_SRC_DIR}/csrc/api/src/jit.cpp
${TORCH_SRC_DIR}/csrc/api/src/nn/cursor.cpp
${TORCH_SRC_DIR}/csrc/api/src/nn/init.cpp
${TORCH_SRC_DIR}/csrc/api/src/nn/module.cpp
@@ -232,13 +232,17 @@ if (NOT NO_API AND NOT USE_ROCM)
${TORCH_SRC_DIR}/csrc/api/src/nn/modules/functional.cpp
${TORCH_SRC_DIR}/csrc/api/src/nn/modules/linear.cpp
${TORCH_SRC_DIR}/csrc/api/src/nn/modules/rnn.cpp
${TORCH_SRC_DIR}/csrc/api/src/optim/optimizer.cpp
${TORCH_SRC_DIR}/csrc/api/src/optim/adam.cpp
${TORCH_SRC_DIR}/csrc/api/src/optim/adagrad.cpp
${TORCH_SRC_DIR}/csrc/api/src/optim/adam.cpp
${TORCH_SRC_DIR}/csrc/api/src/optim/lbfgs.cpp
${TORCH_SRC_DIR}/csrc/api/src/optim/optimizer.cpp
${TORCH_SRC_DIR}/csrc/api/src/optim/rmsprop.cpp
${TORCH_SRC_DIR}/csrc/api/src/optim/serialize.cpp
${TORCH_SRC_DIR}/csrc/api/src/optim/sgd.cpp
${TORCH_SRC_DIR}/csrc/api/src/jit.cpp
${TORCH_SRC_DIR}/csrc/api/src/serialize.cpp
${TORCH_SRC_DIR}/csrc/api/src/serialize/input-archive.cpp
${TORCH_SRC_DIR}/csrc/api/src/serialize/output-archive.cpp
${TORCH_SRC_DIR}/csrc/api/src/utils.cpp
)
endif()
@@ -330,13 +334,6 @@ if (NOT NO_API AND NOT USE_ROCM)
target_include_directories(torch PUBLIC
${TORCH_SRC_DIR}/csrc/api
${TORCH_SRC_DIR}/csrc/api/include)
if (TORCH_USE_CEREAL)
target_compile_definitions(torch PUBLIC TORCH_USE_CEREAL)
# SYSTEM headers are included with -isystem and thus do not trigger warnings.
target_include_directories(torch SYSTEM PUBLIC
"${TORCH_ROOT}/third_party/cereal/include") # For cereal/
endif()
endif()
if(USE_CUDA)
@@ -445,6 +442,7 @@ if (BUILD_TEST AND NOT NO_API AND NOT USE_ROCM)
${TORCH_API_TEST_DIR}/any.cpp
${TORCH_API_TEST_DIR}/cursor.cpp
${TORCH_API_TEST_DIR}/integration.cpp
${TORCH_API_TEST_DIR}/jit.cpp
${TORCH_API_TEST_DIR}/main.cpp
${TORCH_API_TEST_DIR}/misc.cpp
${TORCH_API_TEST_DIR}/module.cpp
@@ -453,17 +451,13 @@
${TORCH_API_TEST_DIR}/parallel.cpp
${TORCH_API_TEST_DIR}/rnn.cpp
${TORCH_API_TEST_DIR}/sequential.cpp
${TORCH_API_TEST_DIR}/serialize.cpp
${TORCH_API_TEST_DIR}/tensor_cuda.cpp
${TORCH_API_TEST_DIR}/tensor.cpp
${TORCH_API_TEST_DIR}/jit.cpp
${TORCH_API_TEST_DIR}/tensor_options.cpp
${TORCH_API_TEST_DIR}/tensor_options_cuda.cpp
${TORCH_API_TEST_DIR}/tensor_options.cpp
${TORCH_API_TEST_DIR}/tensor.cpp
)
if (TORCH_USE_CEREAL)
list(APPEND TORCH_API_TEST_SOURCES ${TORCH_API_TEST_DIR}/serialization.cpp)
endif()
add_executable(test_api ${TORCH_API_TEST_SOURCES})
target_include_directories(test_api

View File

@@ -1,5 +1,8 @@
#pragma once
#include <torch/csrc/utils/variadic.h>
#include <torch/tensor.h>
#include <cstdint>
#include <type_traits>
@@ -52,7 +55,8 @@ inline constexpr bool check_not_lvalue_references<void>() {
/// A type trait whose `value` member is true if `M` derives from `Module`.
template <typename M>
using is_module = std::is_base_of<torch::nn::Module, typename std::decay<M>::type>;
using is_module =
std::is_base_of<torch::nn::Module, typename std::decay<M>::type>;
template <typename M, typename T = void>
using enable_if_module_t =

View File

@@ -3,6 +3,7 @@
#include <torch/detail/ordered_dict.h>
#include <torch/nn/cursor.h>
#include <torch/nn/pimpl.h>
#include <torch/serialize/archive.h>
#include <torch/tensor.h>
#include <ATen/ATen.h>
@@ -208,14 +209,6 @@ class Module {
/// Recursively zeros out the `grad` value of each registered parameter.
virtual void zero_grad();
/// Serializes the `Module`.
template <class Archive>
void save(Archive& ar) const;
/// Deserializes the `Module`.
template <class Archive>
void load(Archive& ar);
/// Attempts to cast this `Module` to the given `ModuleType`.
///
/// This method is useful when calling `apply()` on a `ModuleCursor`.
@@ -255,6 +248,12 @@
typename = torch::detail::disable_if_module_holder_t<ModuleType>>
ModuleType* as() noexcept;
/// Serializes the `Module` into the given `OutputArchive`.
virtual void save(serialize::OutputArchive& archive) const;
/// Deserializes the `Module` from the given `InputArchive`.
virtual void load(serialize::InputArchive& archive);
protected:
/// Registers a parameter with this `Module`.
///
@@ -359,28 +358,6 @@ class Module {
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ nn::Module ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
template <class Archive>
void Module::save(Archive& ar) const {
auto params = parameters();
size_t size = params.size();
ar(size);
for (auto& p : params) {
ar(p.key, p.value);
}
}
template <class Archive>
void Module::load(Archive& ar) {
auto params = parameters();
size_t size;
ar(size);
std::string name;
for (size_t i = 0; i < size; i++) {
ar(name);
ar(params[name]);
}
}
template <typename ModuleType>
typename ModuleType::ContainedType* Module::as() noexcept {
// Use the contained type of the `ModuleHolder`, e.g. `LinearImpl` for

View File

@@ -1,15 +1,20 @@
#pragma once
#include <torch/nn/module.h>
#include <torch/nn/pimpl.h>
#include <torch/optim/optimizer.h>
#include <torch/serialization.h>
#include <torch/optim/serialize.h>
#include <torch/tensor.h>
#include <ATen/ATen.h>
#include <utility>
#include <vector>
namespace torch {
namespace serialize {
class OutputArchive;
class InputArchive;
} // namespace serialize
} // namespace torch
namespace torch {
namespace optim {
@@ -33,29 +38,20 @@ class Adagrad : public Optimizer {
AdagradOptions options;
template <class Archive>
void serialize(Archive& ar) {
#if defined(TORCH_USE_CEREAL)
ar(CEREAL_NVP(sum_));
ar(CEREAL_NVP(step_));
#endif // defined(TORCH_USE_CEREAL)
}
void save(serialize::OutputArchive& archive) const override;
void load(serialize::InputArchive& archive) override;
std::vector<Tensor> sum_buffers;
std::vector<int64_t> step_buffers;
private:
#if defined(TORCH_USE_CEREAL)
friend class cereal::access;
#endif // defined(TORCH_USE_CEREAL)
Adagrad() : options(0) {}
std::vector<Tensor> sum_;
std::vector<int64_t> step_;
template <typename Self, typename Archive>
static void serialize(Self& self, Archive& archive) {
TORCH_OPTIM_SERIALIZE(sum_buffers);
TORCH_OPTIM_SERIALIZE(step_buffers);
}
};
} // namespace optim
} // namespace torch
#if defined(TORCH_USE_CEREAL)
CEREAL_REGISTER_TYPE(torch::optim::Adagrad);
CEREAL_REGISTER_POLYMORPHIC_RELATION(
torch::optim::Optimizer,
torch::optim::Adagrad);
#endif // defined(TORCH_USE_CEREAL)

View File

@@ -3,13 +3,18 @@
#include <torch/arg.h>
#include <torch/nn/module.h>
#include <torch/optim/optimizer.h>
#include <torch/serialization.h>
#include <ATen/ATen.h>
#include <torch/optim/serialize.h>
#include <utility>
#include <vector>
namespace torch {
namespace serialize {
class OutputArchive;
class InputArchive;
} // namespace serialize
} // namespace torch
namespace torch {
namespace optim {
@@ -32,35 +37,26 @@ class Adam : public Optimizer {
void step() override;
template <class Archive>
void serialize(Archive& ar) {
#if defined(TORCH_USE_CEREAL)
ar(CEREAL_NVP(step_buffers_),
CEREAL_NVP(exp_average_buffers_),
CEREAL_NVP(exp_average_sq_buffers_),
CEREAL_NVP(max_exp_average_sq_buffers_));
#endif // defined(TORCH_USE_CEREAL)
}
void save(serialize::OutputArchive& archive) const override;
void load(serialize::InputArchive& archive) override;
AdamOptions options;
std::vector<int64_t> step_buffers;
std::vector<Tensor> exp_average_buffers;
std::vector<Tensor> exp_average_sq_buffers;
std::vector<Tensor> max_exp_average_sq_buffers;
private:
#if defined(TORCH_USE_CEREAL)
friend class cereal::access;
#endif // defined(TORCH_USE_CEREAL)
Adam() : options(0) {}
std::vector<int64_t> step_buffers_;
std::vector<Tensor> exp_average_buffers_;
std::vector<Tensor> exp_average_sq_buffers_;
std::vector<Tensor> max_exp_average_sq_buffers_;
template <typename Self, typename Archive>
static void serialize(Self& self, Archive& archive) {
TORCH_OPTIM_SERIALIZE(step_buffers);
TORCH_OPTIM_SERIALIZE(exp_average_buffers);
TORCH_OPTIM_SERIALIZE(exp_average_sq_buffers);
TORCH_OPTIM_SERIALIZE(max_exp_average_sq_buffers);
}
};
} // namespace optim
} // namespace torch
#if defined(TORCH_USE_CEREAL)
CEREAL_REGISTER_TYPE(torch::optim::Adam);
CEREAL_REGISTER_POLYMORPHIC_RELATION(
torch::optim::Optimizer,
torch::optim::Adam);
#endif // defined(TORCH_USE_CEREAL)

View File

@@ -3,9 +3,8 @@
#include <torch/arg.h>
#include <torch/nn/module.h>
#include <torch/optim/optimizer.h>
#include <torch/serialization.h>
#include <ATen/ATen.h>
#include <torch/optim/serialize.h>
#include <torch/serialize/archive.h>
#include <deque>
#include <functional>
@@ -38,39 +37,37 @@ class LBFGS : public LossClosureOptimizer {
LBFGSOptions options;
template <class Archive>
void serialize(Archive& ar) {
#if defined(TORCH_USE_CEREAL)
ar(CEREAL_NVP(d));
ar(CEREAL_NVP(t));
ar(CEREAL_NVP(H_diag));
ar(CEREAL_NVP(prev_flat_grad));
ar(CEREAL_NVP(prev_loss));
ar(CEREAL_NVP(old_dirs));
ar(CEREAL_NVP(old_stps));
#endif // defined(TORCH_USE_CEREAL)
}
private:
#if defined(TORCH_USE_CEREAL)
friend class cereal::access;
#endif // defined(TORCH_USE_CEREAL)
LBFGS() : options(0) {}
Tensor gather_flat_grad();
void add_grad(const torch::Scalar& step_size, const Tensor& update);
void save(serialize::OutputArchive& archive) const override;
void load(serialize::InputArchive& archive) override;
Tensor d{torch::empty({0})};
Tensor H_diag{torch::empty({0})};
Tensor prev_flat_grad{torch::empty({0})};
torch::Scalar t{0};
torch::Scalar prev_loss{0};
Tensor t{torch::zeros(1)};
Tensor prev_loss{torch::zeros(1)};
std::vector<Tensor> ro;
std::vector<Tensor> al;
std::deque<Tensor> old_dirs;
std::deque<Tensor> old_stps;
int64_t func_evals{0};
int64_t state_n_iter{0};
private:
LBFGS() : options(0) {}
Tensor gather_flat_grad();
void add_grad(const torch::Tensor& step_size, const Tensor& update);
template <typename Self, typename Archive>
static void serialize(Self& self, Archive& archive) {
archive("d", self.d, /*is_buffer=*/true);
archive("t", self.t, /*is_buffer=*/true);
archive("H_diag", self.H_diag, /*is_buffer=*/true);
archive("prev_flat_grad", self.prev_flat_grad, /*is_buffer=*/true);
archive("prev_loss", self.prev_loss, /*is_buffer=*/true);
detail::serialize(archive, "old_dirs", self.old_dirs);
detail::serialize(archive, "old_stps", self.old_stps);
}
};
} // namespace optim
} // namespace torch

View File

@@ -1,18 +1,22 @@
#pragma once
#include <torch/csrc/autograd/generated/variable_factories.h>
#include <torch/csrc/utils/variadic.h>
#include <torch/nn/cursor.h>
#include <torch/optim/serialize.h>
#include <torch/serialize/archive.h>
#include <torch/tensor.h>
#include <algorithm>
#include <functional>
#include <iterator>
#include <memory>
#include <string>
#include <vector>
namespace torch {
namespace optim {
namespace detail {
/// Base class for all optimizers, that does not yet define a `step()`
/// mechanism. All it specifies is that optimizers must be supplied with a
/// vector of parameters. It also defines certain methods that all optimizers
@@ -48,6 +52,9 @@ class OptimizerBase {
/// Returns the number of parameters referenced by the optimizer.
size_t size() const noexcept;
virtual void save(serialize::OutputArchive& archive) const;
virtual void load(serialize::InputArchive& archive);
protected:
OptimizerBase() = default;

View File

@@ -3,15 +3,21 @@
#include <torch/arg.h>
#include <torch/nn/module.h>
#include <torch/optim/optimizer.h>
#include <torch/serialization.h>
#include <ATen/ATen.h>
#include <torch/optim/serialize.h>
#include <torch/tensor.h>
#include <functional>
#include <memory>
#include <string>
#include <vector>
namespace torch {
namespace serialize {
class OutputArchive;
class InputArchive;
} // namespace serialize
} // namespace torch
namespace torch {
namespace optim {
@@ -38,31 +44,22 @@ class RMSprop : public Optimizer {
RMSpropOptions options;
template <class Archive>
void serialize(Archive& ar) {
#if defined(TORCH_USE_CEREAL)
ar(CEREAL_NVP(square_average_buffers_));
ar(CEREAL_NVP(momentum_buffers_));
ar(CEREAL_NVP(grad_average_buffers_));
#endif // defined(TORCH_USE_CEREAL)
}
void save(serialize::OutputArchive& archive) const override;
void load(serialize::InputArchive& archive) override;
std::vector<Tensor> square_average_buffers;
std::vector<Tensor> momentum_buffers;
std::vector<Tensor> grad_average_buffers;
private:
#if defined(TORCH_USE_CEREAL)
friend class cereal::access;
#endif // defined(TORCH_USE_CEREAL)
RMSprop() : options(0) {}
std::vector<Tensor> square_average_buffers_;
std::vector<Tensor> momentum_buffers_;
std::vector<Tensor> grad_average_buffers_;
template <typename Self, typename Archive>
static void serialize(Self& self, Archive& archive) {
TORCH_OPTIM_SERIALIZE(square_average_buffers);
TORCH_OPTIM_SERIALIZE(momentum_buffers);
TORCH_OPTIM_SERIALIZE(grad_average_buffers);
}
};
} // namespace optim
} // namespace torch
#if defined(TORCH_USE_CEREAL)
CEREAL_REGISTER_TYPE(torch::optim::RMSprop);
CEREAL_REGISTER_POLYMORPHIC_RELATION(
torch::optim::Optimizer,
torch::optim::RMSprop);
#endif // defined(TORCH_USE_CEREAL)

View File

@@ -0,0 +1,67 @@
#pragma once
#include <torch/serialize/archive.h>
#include <torch/tensor.h>
#include <cstddef>
#include <cstdint>
#include <deque>
#include <string>
#include <vector>
namespace torch {
namespace optim {
namespace detail {
// Note: These functions are all called `serialize()` so they can be called
// inside a template where the archive type is a template type and can thus be
// passed such that the appropriate overload is selected.
/// Utility function to save a vector of step buffers.
void serialize(
serialize::OutputArchive& archive,
const std::string& key,
const std::vector<int64_t>& steps);
/// Utility function to load a vector of step buffers.
void serialize(
serialize::InputArchive& archive,
const std::string& key,
std::vector<int64_t>& steps);
/// Utility function to save a vector of buffers.
template <typename BufferContainer>
void serialize(
serialize::OutputArchive& archive,
const std::string& key,
const BufferContainer& buffers) {
archive.write(
key + "/size", torch::tensor(static_cast<int64_t>(buffers.size())));
for (size_t index = 0; index < buffers.size(); ++index) {
archive.write(
key + "/" + std::to_string(index), buffers[index], /*is_buffer=*/true);
}
}
/// Utility function to load a vector of buffers.
template <typename BufferContainer>
void serialize(
serialize::InputArchive& archive,
const std::string& key,
BufferContainer& buffers) {
torch::Tensor size_tensor;
archive.read(key + "/size", size_tensor);
const size_t size = size_tensor.toCLong();
for (size_t index = 0; index < size; ++index) {
buffers.emplace_back();
archive.read(
key + "/" + std::to_string(index), buffers.back(), /*is_buffer=*/true);
}
}
#define TORCH_OPTIM_SERIALIZE(name) \
torch::optim::detail::serialize(archive, #name, self.name)
} // namespace detail
} // namespace optim
} // namespace torch

View File

@@ -3,15 +3,19 @@
#include <torch/arg.h>
#include <torch/nn/module.h>
#include <torch/optim/optimizer.h>
#include <torch/serialization.h>
#include <torch/tensor.h>
#include <ATen/ATen.h>
#include <cstddef>
#include <utility>
#include <vector>
namespace torch {
namespace serialize {
class OutputArchive;
class InputArchive;
} // namespace serialize
} // namespace torch
namespace torch {
namespace optim {
@@ -33,31 +37,18 @@ class SGD : public Optimizer {
void step() override;
template <class Archive>
void serialize(Archive& ar) {
#if defined(TORCH_USE_CEREAL)
ar(CEREAL_NVP(momentum_buffers_));
#endif // defined(TORCH_USE_CEREAL)
}
void save(serialize::OutputArchive& archive) const override;
void load(serialize::InputArchive& archive) override;
SGDOptions options;
std::vector<Tensor> momentum_buffers;
private:
#if defined(TORCH_USE_CEREAL)
friend class cereal::access;
#endif // defined(TORCH_USE_CEREAL)
SGD() : options(0) {}
std::vector<Tensor> momentum_buffers_;
/// Counts how often `step()` is called, for dampening.
size_t iteration_{0};
};
} // namespace optim
} // namespace torch
#if defined(TORCH_USE_CEREAL)
CEREAL_REGISTER_TYPE(torch::optim::SGD);
CEREAL_REGISTER_POLYMORPHIC_RELATION(
torch::optim::Optimizer,
torch::optim::SGD);
#endif // defined(TORCH_USE_CEREAL)

View File

@@ -1,276 +0,0 @@
#pragma once
#include <fstream>
#include <torch/tensor.h>
#include <torch/utils.h>
#if defined(TORCH_USE_CEREAL)
#include <cereal/access.hpp>
#include <cereal/cereal.hpp>
#include <cereal/types/polymorphic.hpp>
#include "cereal/archives/binary.hpp"
#include "cereal/types/string.hpp"
#include "cereal/types/unordered_map.hpp"
#include "cereal/types/vector.hpp"
#endif // defined(TORCH_USE_CEREAL)
namespace torch {
// Some convenience functions for saving and loading
template <typename T>
void save(std::ostream& stream, T const& obj) {
#if defined(TORCH_USE_CEREAL)
cereal::BinaryOutputArchive archive(stream);
archive(*obj);
#else
AT_ERROR("PyTorch compiled without serialization support");
#endif
}
template <typename T>
void load(std::istream& stream, T& obj) {
#if defined(TORCH_USE_CEREAL)
cereal::BinaryInputArchive archive(stream);
archive(*obj);
#else
AT_ERROR("PyTorch compiled without serialization support");
#endif
}
template <typename T>
void save(std::ostream& stream, T const* obj) {
#if defined(TORCH_USE_CEREAL)
cereal::BinaryOutputArchive archive(stream);
archive(*obj);
#else
AT_ERROR("PyTorch compiled without serialization support");
#endif
}
template <typename T>
void load(std::istream& stream, T* obj) {
#if defined(TORCH_USE_CEREAL)
cereal::BinaryInputArchive archive(stream);
archive(*obj);
#else
AT_ERROR("PyTorch compiled without serialization support");
#endif
}
template <typename T>
void save(std::string const& path, T const& obj) {
std::ofstream os(path, std::ios::binary);
torch::save(os, obj);
}
template <typename T>
void load(std::string const& path, T& obj) {
std::ifstream is(path, std::ios::binary);
torch::load(is, obj);
}
namespace detail {
// We use our own hard-coded type<->id mapping so that serialization is robust
// wrt changes in ATen; see e.g. https://git.io/vxd6R
// The mapping is consistent with the ScalarType enum as of pytorch version
// v0.1.11-7675-ge94c67e.
inline int32_t scalarTypeId(torch::Dtype type) {
switch (type) {
case torch::Dtype::Byte:
return 0;
case torch::Dtype::Char:
return 1;
case torch::Dtype::Short:
return 2;
case torch::Dtype::Int:
return 3;
case torch::Dtype::Long:
return 4;
case torch::Dtype::Half:
return 5;
case torch::Dtype::Float:
return 6;
case torch::Dtype::Double:
return 7;
case torch::Dtype::Undefined:
return 8;
default:
AT_ERROR("Unknown scalar type: ", static_cast<int>(type));
}
}
inline torch::Dtype scalarTypeFromId(int32_t id) {
switch (id) {
case 0:
return torch::Dtype::Byte;
case 1:
return torch::Dtype::Char;
case 2:
return torch::Dtype::Short;
case 3:
return torch::Dtype::Int;
case 4:
return torch::Dtype::Long;
case 5:
return torch::Dtype::Half;
case 6:
return torch::Dtype::Float;
case 7:
return torch::Dtype::Double;
case 8:
return torch::Dtype::Undefined;
default:
AT_ERROR("Unknown scalar type id: ", id);
}
}
inline int32_t backendId(at::Backend backend) {
switch (backend) {
case at::Backend::CPU:
return 0;
case at::Backend::CUDA:
return 1;
case at::Backend::SparseCPU:
return 2;
case at::Backend::SparseCUDA:
return 3;
case at::Backend::Undefined:
return 4;
default:
AT_ERROR("Unknown backend: ", static_cast<int>(backend));
}
}
inline at::Backend backendFromId(int32_t id) {
switch (id) {
case 0:
return at::Backend::CPU;
case 1:
return at::Backend::CUDA;
case 2:
return at::Backend::SparseCPU;
case 3:
return at::Backend::SparseCUDA;
case 4:
return at::Backend::Undefined;
default:
AT_ERROR("Unknown backend id: ", id);
}
}
} // namespace detail
} // namespace torch
#if defined(TORCH_USE_CEREAL)
namespace cereal {
namespace agimpl {
template <class Archive>
void saveBinary(Archive& archive, void const* data, size_t size) {
// In general, there's no direct `saveBinary`-like method on archives
std::vector<char> v(
static_cast<char const*>(data), static_cast<char const*>(data) + size);
archive(v);
}
template <>
inline void saveBinary(
BinaryOutputArchive& archive,
void const* data,
size_t size) {
// Writes to output stream without extra copy
archive.saveBinary(data, size);
}
template <class Archive>
void loadBinary(Archive& archive, void* data, size_t size) {
// In general, there's no direct `loadBinary`-like method on archives
std::vector<char> v(size);
archive(v);
std::memcpy(data, v.data(), size);
}
template <>
inline void loadBinary(BinaryInputArchive& archive, void* data, size_t size) {
// Read from input stream without extra copy
archive.loadBinary(data, size);
}
} // namespace agimpl
// Gradients will not be saved for variables
template <class Archive>
void save(Archive& archive, const torch::Tensor& tensor) {
if (!tensor.defined()) {
int32_t typeId = ::torch::detail::scalarTypeId(torch::Dtype::Undefined);
archive(CEREAL_NVP(typeId));
return;
} else {
int32_t typeId = ::torch::detail::scalarTypeId(tensor.dtype());
archive(CEREAL_NVP(typeId));
}
auto sizes = std::vector<int64_t>();
auto buf = std::vector<uint8_t>();
for (auto s : tensor.sizes()) {
sizes.push_back(s);
}
auto contig = tensor.cpu().contiguous();
int32_t backend = ::torch::detail::backendId(tensor.type().backend());
archive(CEREAL_NVP(backend), CEREAL_NVP(sizes));
agimpl::saveBinary(
archive,
contig.data_ptr(),
tensor.numel() * tensor.type().elementSizeInBytes());
}
/**
* We follow these rules for loading:
* 1. If tensor is defined, and the same ScalarType as the saved tensor,
* then we simply copy the data into the tensor, with resizing.
* 2. Otherwise, overwrite the provided tensor with the right type and backend
**/
template <class Archive>
void load(Archive& archive, torch::Tensor& tensor) {
torch::NoGradGuard guard;
torch::Dtype type;
int32_t typeId;
archive(CEREAL_NVP(typeId));
type = ::torch::detail::scalarTypeFromId(typeId);
if (type == torch::Dtype::Undefined) {
tensor = torch::Tensor();
return;
}
int32_t backendId;
auto sizes = std::vector<int64_t>();
auto buf = std::vector<uint8_t>();
archive(CEREAL_NVP(backendId), CEREAL_NVP(sizes));
at::Backend backend = ::torch::detail::backendFromId(backendId);
if (!tensor.defined() || tensor.dtype() != type) {
tensor = torch::empty({}, at::TensorOptions(backend).dtype(type));
}
const auto required_grad = tensor.requires_grad();
tensor.set_requires_grad(false);
tensor.resize_(sizes);
tensor.set_requires_grad(required_grad);
if (tensor.type().is_cuda()) {
// should actually use cudamemcpy probably
auto cputensor = torch::empty(sizes, tensor.dtype());
agimpl::loadBinary(
archive,
cputensor.data_ptr(),
cputensor.numel() * cputensor.type().elementSizeInBytes());
tensor.copy_(cputensor);
} else {
agimpl::loadBinary(
archive,
tensor.data_ptr(),
tensor.numel() * tensor.type().elementSizeInBytes());
}
}
} // namespace cereal
#endif // defined(TORCH_USE_CEREAL)

View File

@@ -0,0 +1,49 @@
#pragma once
#include <torch/nn/module.h>
#include <torch/nn/pimpl.h>
#include <torch/serialize/archive.h>
#include <torch/tensor.h>
#include <string>
#include <utility>
namespace torch {
namespace optim {
class Optimizer;
} // namespace optim
} // namespace torch
namespace torch {
namespace serialize {
template <typename ModuleType>
void save(const nn::ModuleHolder<ModuleType>& module, OutputArchive& archive) {
module->save(archive);
}
template <typename ModuleType>
void load(nn::ModuleHolder<ModuleType>& module, InputArchive& archive) {
module->load(archive);
}
void save(const Tensor& tensor, OutputArchive& archive);
void save(const optim::Optimizer& optimizer, OutputArchive& archive);
void load(optim::Optimizer& optimizer, InputArchive& archive);
} // namespace serialize
template <typename T>
void save(const T& value, const std::string& filename) {
serialize::OutputArchive archive;
serialize::save(value, archive);
serialize::save_to_file(archive, filename);
}
template <typename T>
void load(T& value, const std::string& filename) {
serialize::InputArchive archive = serialize::load_from_file(filename);
serialize::load(value, archive);
}
Tensor load(const std::string& filename);
} // namespace torch

View File

@@ -0,0 +1,4 @@
#pragma once
#include <torch/serialize/input-archive.h>
#include <torch/serialize/output-archive.h>

View File

@@ -0,0 +1,41 @@
#pragma once
#include <torch/tensor.h>
#include <memory>
#include <string>
#include <utility>
namespace torch {
namespace jit {
namespace script {
struct Module;
} // namespace script
} // namespace jit
} // namespace torch
namespace torch {
namespace serialize {
class InputArchive final {
public:
InputArchive();
void read(const std::string& key, Tensor& tensor, bool is_buffer = false);
void read(const std::string& key, InputArchive& archive);
template <typename... Ts>
void operator()(Ts&&... ts) {
read(std::forward<Ts>(ts)...);
}
private:
friend InputArchive load_from_file(const std::string& filename);
InputArchive(std::shared_ptr<jit::script::Module> module);
std::shared_ptr<jit::script::Module> module_;
};
InputArchive load_from_file(const std::string& filename);
} // namespace serialize
} // namespace torch

View File

@@ -0,0 +1,42 @@
#pragma once
#include <torch/tensor.h>
#include <memory>
#include <string>
#include <utility>
namespace torch {
namespace jit {
namespace script {
struct Module;
} // namespace script
} // namespace jit
} // namespace torch
namespace torch {
namespace serialize {
class OutputArchive final {
public:
OutputArchive();
void write(
const std::string& key,
const Tensor& tensor,
bool is_buffer = false);
void write(const std::string& key, OutputArchive& nested_archive);
template <typename... Ts>
void operator()(Ts&&... ts) {
write(std::forward<Ts>(ts)...);
}
private:
friend void save_to_file(const OutputArchive&, const std::string&);
std::shared_ptr<jit::script::Module> module_;
};
void save_to_file(const OutputArchive& archive, const std::string& filename);
} // namespace serialize
} // namespace torch

View File

@@ -3,6 +3,6 @@
#include <torch/cuda.h>
#include <torch/nn.h>
#include <torch/optim.h>
#include <torch/serialization.h>
#include <torch/serialize.h>
#include <torch/tensor.h>
#include <torch/utils.h>

View File

@@ -119,6 +119,39 @@ void Module::zero_grad() {
}
}
void Module::save(serialize::OutputArchive& archive) const {
for (const auto& parameter : parameters_) {
archive.write(parameter.key, parameter.value);
}
for (const auto& buffer : buffers_) {
archive.write(buffer.key, buffer.value, /*is_buffer=*/true);
}
for (const auto& child : children_) {
serialize::OutputArchive child_archive;
child.value->save(child_archive);
archive.write(child.key, child_archive);
}
}
void Module::load(serialize::InputArchive& archive) {
for (auto& parameter : parameters_) {
archive.read(parameter.key, parameter.value);
}
for (auto& buffer : buffers_) {
archive.read(buffer.key, buffer.value, /*is_buffer=*/true);
}
for (const auto& child : children_) {
// Modules that have no state at all (parameters or buffers) are currently
// not stored in Protobuf at all, so we can just skip them.
if (!child.value->parameters_.is_empty() ||
!child.value->buffers_.is_empty()) {
serialize::InputArchive child_archive;
archive.read(child.key, child_archive);
child.value->load(child_archive);
}
}
}
Tensor& Module::register_parameter(
std::string name,
Tensor tensor,

View File

@@ -1,6 +1,7 @@
#include <torch/optim/adagrad.h>
#include <torch/csrc/autograd/variable.h>
#include <torch/serialize/archive.h>
#include <torch/utils.h>
#include <ATen/ATen.h>
@@ -26,17 +27,25 @@ void Adagrad::step() {
p.grad() = p.grad() + options.weight_decay_ * p;
}
buffer_at(step_, i) += 1.0;
buffer_at(step_buffers, i) += 1.0;
const auto clr = options.learning_rate_ /
(1.0 + (buffer_at(step_, i) - 1.0) * options.lr_decay_);
(1.0 + (buffer_at(step_buffers, i) - 1.0) * options.lr_decay_);
auto& sum = buffer_at(sum_, i);
auto& sum = buffer_at(sum_buffers, i);
sum.addcmul_(p.grad(), p.grad(), 1.0);
const auto std = buffer_at(sum_, i).sqrt().add_(1e-10);
const auto std = buffer_at(sum_buffers, i).sqrt().add_(1e-10);
NoGradGuard guard;
p.addcdiv_(p.grad(), std, -clr);
}
}
void Adagrad::save(serialize::OutputArchive& archive) const {
serialize(*this, archive);
}
void Adagrad::load(serialize::InputArchive& archive) {
serialize(*this, archive);
}
} // namespace optim
} // namespace torch

View File

@@ -2,6 +2,7 @@
#include <torch/csrc/autograd/variable.h>
#include <torch/nn/module.h>
#include <torch/serialize/archive.h>
#include <torch/utils.h>
#include <ATen/ATen.h>
@@ -11,7 +12,6 @@
namespace torch {
namespace optim {
AdamOptions::AdamOptions(double learning_rate)
: learning_rate_(learning_rate) {}
@@ -26,10 +26,10 @@ void Adam::step() {
p.grad() = p.grad() + options.weight_decay_ * p;
}
auto& exp_average = buffer_at(exp_average_buffers_, i);
auto& exp_average_sq = buffer_at(exp_average_sq_buffers_, i);
auto& exp_average = buffer_at(exp_average_buffers, i);
auto& exp_average_sq = buffer_at(exp_average_sq_buffers, i);
buffer_at(step_buffers_, i) += 1;
buffer_at(step_buffers, i) += 1;
exp_average.mul_(options.beta1_).add_(p.grad(), 1 - options.beta1_);
exp_average_sq.mul_(options.beta2_)
@@ -37,15 +37,15 @@ void Adam::step() {
Tensor denom = exp_average_sq;
if (options.amsgrad_) {
auto& max_exp_average_sq = buffer_at(max_exp_average_sq_buffers_, i);
auto& max_exp_average_sq = buffer_at(max_exp_average_sq_buffers, i);
max_exp_average_sq = torch::max(max_exp_average_sq, exp_average_sq);
denom = max_exp_average_sq;
}
const auto bias_correction1 =
1 - std::pow(options.beta1_, buffer_at(step_buffers_, i));
1 - std::pow(options.beta1_, buffer_at(step_buffers, i));
const auto bias_correction2 =
1 - std::pow(options.beta2_, buffer_at(step_buffers_, i));
1 - std::pow(options.beta2_, buffer_at(step_buffers, i));
const auto step_size =
options.learning_rate_ * std::sqrt(bias_correction2) / bias_correction1;
@@ -54,5 +54,12 @@ void Adam::step() {
}
}
void Adam::save(serialize::OutputArchive& archive) const {
serialize(*this, archive);
}
void Adam::load(serialize::InputArchive& archive) {
serialize(*this, archive);
}
} // namespace optim
} // namespace torch

View File

@@ -2,6 +2,7 @@
#include <torch/csrc/autograd/generated/variable_factories.h>
#include <torch/csrc/autograd/variable.h>
#include <torch/serialize/archive.h>
#include <ATen/ATen.h>
@@ -23,12 +24,14 @@ Tensor LBFGS::gather_flat_grad() {
return at::cat(views);
}
void LBFGS::add_grad(const torch::Scalar& step_size, const Tensor& update) {
void LBFGS::add_grad(const torch::Tensor& step_size, const Tensor& update) {
int64_t offset = 0;
for (auto& parameter : parameters_) {
int64_t numel = parameter.numel();
Tensor& pd = autograd::Variable(parameter).data();
pd.add_(update.slice(0, offset, offset + numel, 1).view_as(pd), step_size);
pd.add_(
update.slice(0, offset, offset + numel, 1).view_as(pd),
step_size.toCFloat());
offset += numel;
}
}
@@ -109,9 +112,9 @@ torch::Tensor LBFGS::step(LossClosure closure) {
// reset initial guess for step size
if (n_iter == 1) {
t = at::_local_scalar(at::min(ONE, ONE / abs_grad_sum) * options.learning_rate_);
t = torch::min(ONE, ONE / abs_grad_sum) * options.learning_rate_;
} else {
t = options.learning_rate_;
t = at::tensor(options.learning_rate_, torch::kFloat32);
}
Tensor gtd = flat_grad.dot(d);
@@ -141,17 +144,23 @@ torch::Tensor LBFGS::step(LossClosure closure) {
break;
} else if (gtd.toCFloat() > -options.tolerance_grad_) {
break;
} else if (
d.mul(t).abs_().sum().toCFloat() <=
options.tolerance_change_) {
} else if (d.mul(t).abs_().sum().toCFloat() <= options.tolerance_change_) {
break;
} else if (
std::abs(loss.toCFloat() - prev_loss.toFloat()) <
std::abs(loss.toCFloat() - prev_loss.toCFloat()) <
options.tolerance_change_) {
break;
}
}
return orig_loss;
}
void LBFGS::save(serialize::OutputArchive& archive) const {
serialize(*this, archive);
}
void LBFGS::load(serialize::InputArchive& archive) {
serialize(*this, archive);
}
} // namespace optim
} // namespace torch

View File

@@ -3,13 +3,13 @@
#include <torch/nn/cursor.h>
#include <torch/tensor.h>
#include <string>
#include <utility>
#include <vector>
namespace torch {
namespace optim {
namespace detail {
OptimizerBase::OptimizerBase(std::vector<Tensor> parameters)
: parameters_(std::move(parameters)) {}
@@ -47,6 +47,9 @@ std::vector<Tensor>& OptimizerBase::parameters() noexcept {
size_t OptimizerBase::size() const noexcept {
return parameters_.size();
}
void OptimizerBase::save(serialize::OutputArchive& archive) const {}
void OptimizerBase::load(serialize::InputArchive& archive) {}
} // namespace detail
} // namespace optim
} // namespace torch

View File

@@ -1,6 +1,7 @@
#include <torch/optim/rmsprop.h>
#include <torch/csrc/autograd/variable.h>
#include <torch/serialize/archive.h>
#include <torch/utils.h>
#include <ATen/ATen.h>
@@ -26,13 +27,13 @@ void RMSprop::step() {
p.grad() = p.grad() + options.weight_decay_ * p;
}
auto square_average = buffer_at(square_average_buffers_, i);
auto square_average = buffer_at(square_average_buffers, i);
square_average.mul_(options.alpha_)
.addcmul_(p.grad(), p.grad(), 1.0 - options.alpha_);
Tensor average;
if (options.centered_ > 0) {
auto& grad_average = buffer_at(grad_average_buffers_, i);
auto& grad_average = buffer_at(grad_average_buffers, i);
grad_average.mul_(options.alpha_).add_(p.grad(), 1.0 - options.alpha_);
average = square_average.addcmul(grad_average, grad_average, -1.0)
.sqrt()
@@ -43,7 +44,7 @@
NoGradGuard guard;
if (options.momentum_ > 0) {
auto& momentum = buffer_at(momentum_buffers_, i);
auto& momentum = buffer_at(momentum_buffers, i);
momentum.mul_(options.momentum_).addcdiv_(p.grad(), average);
p.add_(momentum, -options.learning_rate_);
} else {
@@ -51,5 +52,13 @@
}
}
}
void RMSprop::save(serialize::OutputArchive& archive) const {
serialize(*this, archive);
}
void RMSprop::load(serialize::InputArchive& archive) {
serialize(*this, archive);
}
} // namespace optim
} // namespace torch

View File

@@ -0,0 +1,39 @@
#include <torch/optim/serialize.h>
#include <torch/serialize/archive.h>
#include <torch/tensor.h>
#include <cstddef>
#include <cstdint>
#include <deque>
#include <string>
#include <vector>
namespace torch {
namespace optim {
namespace detail {
void serialize(
serialize::OutputArchive& archive,
const std::string& key,
const std::vector<int64_t>& steps) {
std::vector<torch::Tensor> tensors;
for (const auto& step : steps) {
tensors.push_back(torch::tensor(static_cast<int64_t>(step)));
}
serialize(archive, key, tensors);
}
void serialize(
serialize::InputArchive& archive,
const std::string& key,
std::vector<int64_t>& steps) {
std::vector<torch::Tensor> tensors;
serialize(archive, key, tensors);
steps.clear();
for (const auto& step : tensors) {
steps.push_back(step.toCLong());
}
}
} // namespace detail
} // namespace optim
} // namespace torch

View File

@@ -1,6 +1,10 @@
#include <torch/optim/sgd.h>
#include <torch/csrc/autograd/variable.h>
#include <torch/nn/pimpl.h>
#include <torch/optim/optimizer.h>
#include <torch/optim/serialize.h>
#include <torch/tensor.h>
#include <torch/utils.h>
#include <ATen/ATen.h>
@@ -22,7 +26,7 @@ void SGD::step() {
auto update = options.learning_rate_ * p.grad();
if (options.momentum_ != 0) {
const auto dampening = iteration_ == 0 ? 1 : 1 - options.dampening_;
auto& momentum = buffer_at(momentum_buffers_, i);
auto& momentum = buffer_at(momentum_buffers, i);
momentum = (options.momentum_ * momentum) + (dampening * update);
if (options.nesterov_) {
// See github.com/lisa-lab/pylearn2/pull/136#issuecomment-10381617
@@ -42,5 +46,13 @@
}
iteration_ += 1;
}
void SGD::save(serialize::OutputArchive& archive) const {
detail::serialize(archive, "momentum_buffers", momentum_buffers);
}
void SGD::load(serialize::InputArchive& archive) {
detail::serialize(archive, "momentum_buffers", momentum_buffers);
}
} // namespace optim
} // namespace torch

View File

@@ -0,0 +1,31 @@
#include <torch/serialize.h>
#include <torch/optim/optimizer.h>
#include <torch/serialize/archive.h>
#include <torch/tensor.h>
#include <string>
#include <utility>
namespace torch {
namespace serialize {
void save(const Tensor& tensor, OutputArchive& archive) {
archive.write("0", tensor);
}
void save(const optim::Optimizer& optimizer, OutputArchive& archive) {
optimizer.save(archive);
}
void load(optim::Optimizer& optimizer, InputArchive& archive) {
optimizer.load(archive);
}
} // namespace serialize
Tensor load(const std::string& filename) {
serialize::InputArchive archive = serialize::load_from_file(filename);
Tensor tensor;
archive.read("0", tensor);
return tensor;
}
} // namespace torch

View File

@@ -0,0 +1,57 @@
#include <torch/serialize/input-archive.h>
#include <torch/tensor.h>
#include <torch/utils.h>
#include <torch/csrc/jit/import.h>
#include <torch/csrc/jit/script/module.h>
#include <ATen/Error.h>
#include <memory>
#include <string>
#include <utility>
namespace torch {
namespace serialize {
InputArchive::InputArchive()
: module_(std::make_shared<jit::script::Module>()) {}
void InputArchive::read(
const std::string& key,
Tensor& tensor,
bool is_buffer) {
auto* read_tensor = module_->find_parameter(key);
AT_CHECK(read_tensor != nullptr, "No such serialized tensor '", key, "'");
// clang-format off
AT_CHECK(
read_tensor->is_buffer == is_buffer,
"Expected deserialized tensor for key '", key,
"' to ", is_buffer ? "not " : "", "be a buffer, but it was not");
// clang-format on
if (tensor.defined()) {
torch::NoGradGuard guard;
tensor.set_(*read_tensor->slot());
} else {
tensor = std::move(*read_tensor->slot());
}
}
void InputArchive::read(const std::string& key, InputArchive& archive) {
if (auto* named_module = module_->find_module(key)) {
AT_ASSERT(named_module->module != nullptr);
archive.module_ = std::move(named_module->module);
} else {
AT_ERROR("No such serialized submodule: '", key, "'");
}
}
InputArchive::InputArchive(std::shared_ptr<jit::script::Module> module)
: module_(std::move(module)) {}
InputArchive load_from_file(const std::string& filename) {
return InputArchive(torch::jit::load(filename));
}
} // namespace serialize
} // namespace torch

View File

@@ -0,0 +1,34 @@
#include <torch/serialize/output-archive.h>
#include <torch/tensor.h>
#include <torch/utils.h>
#include <torch/csrc/jit/export.h>
#include <torch/csrc/jit/script/module.h>
#include <memory>
#include <string>
namespace torch {
namespace serialize {
OutputArchive::OutputArchive()
: module_(std::make_shared<jit::script::Module>()) {}
void OutputArchive::write(
const std::string& key,
const Tensor& tensor,
bool is_buffer) {
module_->register_parameter(key, tensor, is_buffer);
}
void OutputArchive::write(
const std::string& key,
OutputArchive& nested_archive) {
module_->register_module(key, nested_archive.module_);
}
void save_to_file(const OutputArchive& archive, const std::string& filename) {
jit::ExportModule(*archive.module_, filename);
}
} // namespace serialize
} // namespace torch

View File

@@ -334,6 +334,7 @@ at::Tensor ModuleDecoder::buildTensorCommon(
std::pair<std::shared_ptr<script::Module>, std::string> ModuleDecoder::parseFullName(
ModuleLookup module_lookup,
const std::string fullname) {
AT_ASSERT(!fullname.empty());
std::vector<std::string> vec;
std::stringstream ss(fullname);
std::string name;