Protobuf serialization (#11619)
Summary:
This PR serves two purposes:
1. Design an abstraction over a serialization scheme for C++ modules, optimizers and tensors in general,
2. Add serialization to the ONNX/PyTorch proto format.
This is currently a rough prototype I coded up today, to get quick feedback.
For this I propose the following serialization interface within the C++ API:
```cpp
namespace torch { namespace serialize {
class Reader {
 public:
  virtual ~Reader() = default;
  virtual void read(const std::string& key, Tensor& tensor, bool is_buffer = false) = 0;
  virtual void finish() {}
};

class Writer {
 public:
  virtual ~Writer() = default;
  virtual void write(const std::string& key, const Tensor& tensor, bool is_buffer = false) = 0;
  virtual void finish() {}
};
}} // namespace torch::serialize
```
There are then subclasses of these two for (1) cereal and (2) protobuf (called `DefaultWriter` and `DefaultReader` to hide the implementation details). See `torch/serialize/cereal.h` and `torch/serialize/default.h`. This abstraction, with subclasses for these two backends, allows us to:
1. Provide a cereal-less serialization path that we can ship and iterate on going forward,
2. Provide friction-free backwards compatibility with existing C++ API uses, mainly StarCraft.
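For a sense of how a backend plugs into this interface, here is a minimal sketch of a custom `Writer`; the `KeyLoggingWriter` name and its body are illustrative assumptions, not part of this PR:
```cpp
#include <iostream>
#include <string>
#include <torch/tensor.h>

// Hypothetical backend for illustration: logs each key instead of encoding
// the tensor. A real backend (e.g. the protobuf-based DefaultWriter) would
// serialize the tensor contents here.
class KeyLoggingWriter : public torch::serialize::Writer {
 public:
  void write(const std::string& key, const torch::Tensor& tensor, bool is_buffer)
      override {
    std::cout << (is_buffer ? "[buffer] " : "[param] ") << key << " with sizes "
              << tensor.sizes() << std::endl;
  }
  // finish() keeps its default no-op implementation.
};
```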
The user-facing API is (conceptually):
```cpp
void torch::save(const Module& module, Writer& writer);
void torch::save(const Optimizer& optimizer, Writer& writer);
void torch::read(Module& module, Reader& reader);
void torch::read(Optimizer& optimizer, Reader& reader);
```
with implementations for both optimizers and modules that write into the `Writer` and read from the `Reader`.
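A checkpoint round trip under this interface might then look like the following sketch (the filename constructors of `DefaultWriter` and `DefaultReader` are assumptions; the prototype does not pin down how archives are bound to streams or files):
```cpp
#include <torch/torch.h>

// Sketch only: saves a module's tensors through a Writer and restores them
// through a Reader, per the (conceptual) torch::save/torch::read API above.
void checkpoint_roundtrip(torch::nn::Linear& model) {
  torch::serialize::DefaultWriter writer("checkpoint.pt"); // assumed ctor
  torch::save(*model, writer);
  writer.finish();

  torch::serialize::DefaultReader reader("checkpoint.pt"); // assumed ctor
  torch::read(*model, reader);
  reader.finish();
}
```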
ebetica ezyang zdevito dzhulgakov
Pull Request resolved: https://github.com/pytorch/pytorch/pull/11619
Differential Revision: D9984664
Pulled By: goldsborough
fbshipit-source-id: e03afaa646221546e7f93bb8dfe3558e384a5847
This commit is contained in:
parent 30521a37ad
commit d712a71741
@@ -124,7 +124,6 @@ cmake_dependent_option(
cmake_dependent_option(
    USE_GLOO_IBVERBS "Use Gloo IB verbs for distributed. Only available if USE_GLOO is on." OFF
    "USE_GLOO" OFF)
option(TORCH_USE_CEREAL "Build the C++ API with Cereal for serialization support" OFF)

# Used when building Caffe2 through setup.py
option(BUILDING_WITH_TORCH_LIBS "Tell cmake if Caffe2 is being built alongside torch libs" OFF)

@@ -125,7 +125,6 @@ function (caffe2_print_configuration_summary)
  message(STATUS "  USE_GLOO : ${USE_GLOO}")
  message(STATUS "  USE_GLOO_IBVERBS : ${USE_GLOO_IBVERBS}")
endif()
message(STATUS "  TORCH_USE_CEREAL : ${TORCH_USE_CEREAL}")

message(STATUS "  Public Dependencies : ${Caffe2_PUBLIC_DEPENDENCY_LIBS}")
message(STATUS "  Private Dependencies : ${Caffe2_DEPENDENCY_LIBS}")

@@ -3,6 +3,8 @@
#define CATCH_CONFIG_PREFIX_ALL
#include <catch.hpp>

// CATCH_REQUIRE_THROWS is not defined identically to REQUIRE_THROWS and causes warning;
// define our own version that doesn't warn.
#define _CATCH_REQUIRE_THROWS( ... ) INTERNAL_CATCH_THROWS( "CATCH_REQUIRE_THROWS", Catch::ResultDisposition::Normal, __VA_ARGS__ )
// CATCH_REQUIRE_THROWS is not defined identically to REQUIRE_THROWS and causes
// warning; define our own version that doesn't warn.
#define _CATCH_REQUIRE_THROWS(...) \
  INTERNAL_CATCH_THROWS( \
      "CATCH_REQUIRE_THROWS", Catch::ResultDisposition::Normal, __VA_ARGS__)
@@ -1,339 +0,0 @@
#include "catch_utils.hpp"

#include <torch/nn/modules/functional.h>
#include <torch/nn/modules/linear.h>
#include <torch/nn/modules/sequential.h>
#include <torch/optim/optimizer.h>
#include <torch/optim/sgd.h>
#include <torch/serialization.h>
#include <torch/tensor.h>
#include <torch/utils.h>

#include <test/cpp/api/util.h>

#include <cereal/archives/portable_binary.hpp>

#include <memory>
#include <sstream>
#include <string>
#include <vector>

using namespace torch::nn;

namespace {
Sequential xor_model() {
  return Sequential(
      Linear(2, 8),
      Functional(at::sigmoid),
      Linear(8, 1),
      Functional(at::sigmoid));
}
} // namespace

CATCH_TEST_CASE("serialization") {
  torch::manual_seed(0);
  CATCH_SECTION("undefined") {
    auto x = torch::Tensor();

    CATCH_REQUIRE(!x.defined());

    auto y = torch::randn({5});

    std::stringstream ss;
    torch::save(ss, &x);
    torch::load(ss, &y);

    CATCH_REQUIRE(!y.defined());
  }

  CATCH_SECTION("cputypes") {
    for (int i = 0; i < static_cast<int>(torch::Dtype::NumOptions); i++) {
      if (i == static_cast<int>(torch::Dtype::Half)) {
        // XXX can't serialize half tensors at the moment since contiguous() is
        // not implemented for this type;
        continue;
      } else if (at::isComplexType(static_cast<torch::Dtype>(i))) {
        // Not supported yet
        continue;
      } else if (i == static_cast<int>(torch::Dtype::Undefined)) {
        // We can't construct a tensor for this type. This is tested in
        // serialization/undefined anyway.
        continue;
      }

      auto x = torch::ones(
          {5, 5}, static_cast<torch::Dtype>(i));
      auto y = torch::empty({});

      std::stringstream ss;
      torch::save(ss, &x);
      torch::load(ss, &y);

      CATCH_REQUIRE(y.defined());
      CATCH_REQUIRE(x.sizes().vec() == y.sizes().vec());
      if (torch::isIntegralType(static_cast<torch::Dtype>(i))) {
        CATCH_REQUIRE(x.equal(y));
      } else {
        CATCH_REQUIRE(x.allclose(y));
      }
    }
  }

  CATCH_SECTION("binary") {
    auto x = torch::randn({5, 5});
    auto y = torch::Tensor();

    std::stringstream ss;
    {
      cereal::BinaryOutputArchive archive(ss);
      archive(x);
    }
    {
      cereal::BinaryInputArchive archive(ss);
      archive(y);
    }

    CATCH_REQUIRE(y.defined());
    CATCH_REQUIRE(x.sizes().vec() == y.sizes().vec());
    CATCH_REQUIRE(x.allclose(y));
  }
  CATCH_SECTION("portable_binary") {
    auto x = torch::randn({5, 5});
    auto y = torch::Tensor();

    std::stringstream ss;
    {
      cereal::PortableBinaryOutputArchive archive(ss);
      archive(x);
    }
    {
      cereal::PortableBinaryInputArchive archive(ss);
      archive(y);
    }

    CATCH_REQUIRE(y.defined());
    CATCH_REQUIRE(x.sizes().vec() == y.sizes().vec());
    CATCH_REQUIRE(x.allclose(y));
  }

  CATCH_SECTION("resized") {
    auto x = torch::randn({11, 5});
    x.resize_({5, 5});
    auto y = torch::Tensor();

    std::stringstream ss;
    {
      cereal::BinaryOutputArchive archive(ss);
      archive(x);
    }
    {
      cereal::BinaryInputArchive archive(ss);
      archive(y);
    }

    CATCH_REQUIRE(y.defined());
    CATCH_REQUIRE(x.sizes().vec() == y.sizes().vec());
    CATCH_REQUIRE(x.allclose(y));
  }
  CATCH_SECTION("sliced") {
    auto x = torch::randn({11, 5});
    x = x.slice(0, 1, 3);
    auto y = torch::Tensor();

    std::stringstream ss;
    {
      cereal::BinaryOutputArchive archive(ss);
      archive(x);
    }
    {
      cereal::BinaryInputArchive archive(ss);
      archive(y);
    }

    CATCH_REQUIRE(y.defined());
    CATCH_REQUIRE(x.sizes().vec() == y.sizes().vec());
    CATCH_REQUIRE(x.allclose(y));
  }

  CATCH_SECTION("noncontig") {
    auto x = torch::randn({11, 5});
    x = x.slice(1, 1, 4);
    auto y = torch::Tensor();

    std::stringstream ss;
    {
      cereal::BinaryOutputArchive archive(ss);
      archive(x);
    }
    {
      cereal::BinaryInputArchive archive(ss);
      archive(y);
    }

    CATCH_REQUIRE(y.defined());
    CATCH_REQUIRE(x.sizes().vec() == y.sizes().vec());
    CATCH_REQUIRE(x.allclose(y));
  }

  CATCH_SECTION("xor") {
    // We better be able to save and load a XOR model!
    auto getLoss = [](Sequential model, uint32_t batch_size) {
      auto inputs = torch::empty({batch_size, 2});
      auto labels = torch::empty({batch_size});
      for (size_t i = 0; i < batch_size; i++) {
        inputs[i] = torch::randint(2, {2}, torch::kInt64);
        labels[i] = inputs[i][0].toCLong() ^ inputs[i][1].toCLong();
      }
      auto x = model->forward<torch::Tensor>(inputs);
      return torch::binary_cross_entropy(x, labels);
    };

    auto model = xor_model();
    auto model2 = xor_model();
    auto model3 = xor_model();
    auto optimizer = torch::optim::SGD(
        model->parameters(),
        torch::optim::SGDOptions(1e-1)
            .momentum(0.9)
            .nesterov(true)
            .weight_decay(1e-6));

    float running_loss = 1;
    int epoch = 0;
    while (running_loss > 0.1) {
      torch::Tensor loss = getLoss(model, 4);
      optimizer.zero_grad();
      loss.backward();
      optimizer.step();

      running_loss = running_loss * 0.99 + loss.sum().toCFloat() * 0.01;
      CATCH_REQUIRE(epoch < 3000);
      epoch++;
    }

    std::stringstream ss;
    torch::save(ss, model);
    torch::load(ss, model2);

    auto loss = getLoss(model2, 100);
    CATCH_REQUIRE(loss.toCFloat() < 0.1);
  }

  CATCH_SECTION("optim") {
    auto model1 = Linear(5, 2);
    auto model2 = Linear(5, 2);
    auto model3 = Linear(5, 2);

    // Models 1, 2, 3 will have the same params
    std::stringstream ss;
    torch::save(ss, model1.get());
    torch::load(ss, model2.get());
    ss.seekg(0, std::ios::beg);
    torch::load(ss, model3.get());

    auto param1 = model1->parameters();
    auto param2 = model2->parameters();
    auto param3 = model3->parameters();
    for (const auto& p : param1) {
      CATCH_REQUIRE(param1[p.key].allclose(param2[p.key]));
      CATCH_REQUIRE(param2[p.key].allclose(param3[p.key]));
    }

    // Make some optimizers with momentum (and thus state)
    auto optim1 = torch::optim::SGD(
        model1->parameters(), torch::optim::SGDOptions(1e-1).momentum(0.9));
    auto optim2 = torch::optim::SGD(
        model2->parameters(), torch::optim::SGDOptions(1e-1).momentum(0.9));
    auto optim2_2 = torch::optim::SGD(
        model2->parameters(), torch::optim::SGDOptions(1e-1).momentum(0.9));
    auto optim3 = torch::optim::SGD(
        model3->parameters(), torch::optim::SGDOptions(1e-1).momentum(0.9));
    auto optim3_2 = torch::optim::SGD(
        model3->parameters(), torch::optim::SGDOptions(1e-1).momentum(0.9));

    auto x = torch::ones({10, 5});

    auto step = [&x](torch::optim::Optimizer& optimizer, Linear model) {
      optimizer.zero_grad();
      auto y = model->forward(x).sum();
      y.backward();
      optimizer.step();
    };

    // Do 2 steps of model1
    step(optim1, model1);
    step(optim1, model1);

    // Do 2 steps of model 2 without saving the optimizer
    step(optim2, model2);
    step(optim2_2, model2);

    // Do 2 steps of model 3 while saving the optimizer
    step(optim3, model3);
    ss.clear();
    torch::save(ss, &optim3);
    torch::load(ss, &optim3_2);
    step(optim3_2, model3);

    param1 = model1->parameters();
    param2 = model2->parameters();
    param3 = model3->parameters();
    for (const auto& p : param1) {
      const auto& name = p.key;
      // Model 1 and 3 should be the same
      CATCH_REQUIRE(param1[name].norm().toCFloat() == param3[name].norm().toCFloat());
      CATCH_REQUIRE(param1[name].norm().toCFloat() != param2[name].norm().toCFloat());
    }
  }
}

CATCH_TEST_CASE("serialization_cuda", "[cuda]") {
  torch::manual_seed(0);
  // We better be able to save and load a XOR model!
  auto getLoss = [](Sequential model, uint32_t batch_size) {
    auto inputs = torch::empty({batch_size, 2});
    auto labels = torch::empty({batch_size});
    for (size_t i = 0; i < batch_size; i++) {
      inputs[i] = torch::randint(2, {2}, torch::kInt64);
      labels[i] = inputs[i][0].toCLong() ^ inputs[i][1].toCLong();
    }
    auto x = model->forward<torch::Tensor>(inputs);
    return torch::binary_cross_entropy(x, labels);
  };

  auto model = xor_model();
  auto model2 = xor_model();
  auto model3 = xor_model();
  auto optimizer = torch::optim::SGD(
      model->parameters(),
      torch::optim::SGDOptions(1e-1).momentum(0.9).nesterov(true).weight_decay(
          1e-6));

  float running_loss = 1;
  int epoch = 0;
  while (running_loss > 0.1) {
    torch::Tensor loss = getLoss(model, 4);
    optimizer.zero_grad();
    loss.backward();
    optimizer.step();

    running_loss = running_loss * 0.99 + loss.sum().toCFloat() * 0.01;
    CATCH_REQUIRE(epoch < 3000);
    epoch++;
  }

  std::stringstream ss;
  torch::save(ss, model);
  torch::load(ss, model2);

  auto loss = getLoss(model2, 100);
  CATCH_REQUIRE(loss.toCFloat() < 0.1);

  model2->to(torch::kCUDA);
  ss.clear();
  torch::save(ss, model2);
  torch::load(ss, model3);

  loss = getLoss(model3, 100);
  CATCH_REQUIRE(loss.toCFloat() < 0.1);
}
test/cpp/api/serialize.cpp (new file, 246 lines)
@@ -0,0 +1,246 @@
#include <test/cpp/api/catch_utils.hpp>

#include <torch/nn/modules/functional.h>
#include <torch/nn/modules/linear.h>
#include <torch/nn/modules/sequential.h>
#include <torch/optim/optimizer.h>
#include <torch/optim/sgd.h>
#include <torch/serialize.h>
#include <torch/tensor.h>
#include <torch/utils.h>

#include <test/cpp/api/util.h>

#include <cstdio>
#include <memory>
#include <sstream>
#include <string>
#include <vector>

using namespace torch::nn;
using namespace torch::serialize;

namespace {
Sequential xor_model() {
  return Sequential(
      Linear(2, 8),
      Functional(at::sigmoid),
      Linear(8, 1),
      Functional(at::sigmoid));
}

torch::Tensor save_and_load(torch::Tensor input) {
  torch::test::TempFile tempfile;
  torch::save(input, tempfile.str());
  return torch::load(tempfile.str());
}
} // namespace

CATCH_TEST_CASE("Serialize/Default/Basic") {
  torch::manual_seed(0);

  auto x = torch::randn({5, 5});
  auto y = save_and_load(x);

  CATCH_REQUIRE(y.defined());
  CATCH_REQUIRE(x.sizes().vec() == y.sizes().vec());
  CATCH_REQUIRE(x.allclose(y));
}

CATCH_TEST_CASE("Serialize/Default/Resized") {
  torch::manual_seed(0);

  auto x = torch::randn({11, 5});
  x.resize_({5, 5});
  auto y = save_and_load(x);

  CATCH_REQUIRE(y.defined());
  CATCH_REQUIRE(x.sizes().vec() == y.sizes().vec());
  CATCH_REQUIRE(x.allclose(y));
}

CATCH_TEST_CASE("Serialize/Default/Sliced") {
  torch::manual_seed(0);

  auto x = torch::randn({11, 5});
  x = x.slice(0, 1, 5);
  auto y = save_and_load(x);

  CATCH_REQUIRE(y.defined());
  CATCH_REQUIRE(x.sizes().vec() == y.sizes().vec());
  CATCH_REQUIRE(x.allclose(y));
}

CATCH_TEST_CASE("Serialize/Default/NonContiguous") {
  torch::manual_seed(0);

  auto x = torch::randn({11, 5});
  x = x.slice(1, 1, 4);
  auto y = save_and_load(x);

  CATCH_REQUIRE(y.defined());
  CATCH_REQUIRE(x.sizes().vec() == y.sizes().vec());
  CATCH_REQUIRE(x.allclose(y));
}

CATCH_TEST_CASE("Serialize/Default/XOR") {
  // We better be able to save and load an XOR model!
  auto getLoss = [](Sequential model, uint32_t batch_size) {
    auto inputs = torch::empty({batch_size, 2});
    auto labels = torch::empty({batch_size});
    for (size_t i = 0; i < batch_size; i++) {
      inputs[i] = torch::randint(2, {2}, torch::kInt64);
      labels[i] = inputs[i][0].toCLong() ^ inputs[i][1].toCLong();
    }
    auto x = model->forward<torch::Tensor>(inputs);
    return torch::binary_cross_entropy(x, labels);
  };

  auto model = xor_model();
  auto model2 = xor_model();
  auto model3 = xor_model();
  auto optimizer = torch::optim::SGD(
      model->parameters(),
      torch::optim::SGDOptions(1e-1).momentum(0.9).nesterov(true).weight_decay(
          1e-6));

  float running_loss = 1;
  int epoch = 0;
  while (running_loss > 0.1) {
    torch::Tensor loss = getLoss(model, 4);
    optimizer.zero_grad();
    loss.backward();
    optimizer.step();

    running_loss = running_loss * 0.99 + loss.sum().toCFloat() * 0.01;
    CATCH_REQUIRE(epoch < 3000);
    epoch++;
  }

  torch::test::TempFile tempfile;
  torch::save(model, tempfile.str());
  torch::load(model2, tempfile.str());

  auto loss = getLoss(model2, 100);
  CATCH_REQUIRE(loss.toCFloat() < 0.1);
}

CATCH_TEST_CASE("Serialize/Default/Optim") {
  auto model1 = Linear(5, 2);
  auto model2 = Linear(5, 2);
  auto model3 = Linear(5, 2);

  // Models 1, 2, 3 will have the same parameters.
  torch::test::TempFile model_tempfile;
  torch::save(model1, model_tempfile.str());
  torch::load(model2, model_tempfile.str());
  torch::load(model3, model_tempfile.str());

  auto param1 = model1->parameters();
  auto param2 = model2->parameters();
  auto param3 = model3->parameters();
  for (const auto& p : param1) {
    CATCH_REQUIRE(param1[p.key].allclose(param2[p.key]));
    CATCH_REQUIRE(param2[p.key].allclose(param3[p.key]));
  }

  // Make some optimizers with momentum (and thus state)
  auto optim1 = torch::optim::SGD(
      model1->parameters(), torch::optim::SGDOptions(1e-1).momentum(0.9));
  auto optim2 = torch::optim::SGD(
      model2->parameters(), torch::optim::SGDOptions(1e-1).momentum(0.9));
  auto optim2_2 = torch::optim::SGD(
      model2->parameters(), torch::optim::SGDOptions(1e-1).momentum(0.9));
  auto optim3 = torch::optim::SGD(
      model3->parameters(), torch::optim::SGDOptions(1e-1).momentum(0.9));
  auto optim3_2 = torch::optim::SGD(
      model3->parameters(), torch::optim::SGDOptions(1e-1).momentum(0.9));

  auto x = torch::ones({10, 5});

  auto step = [&x](torch::optim::Optimizer& optimizer, Linear model) {
    optimizer.zero_grad();
    auto y = model->forward(x).sum();
    y.backward();
    optimizer.step();
  };

  // Do 2 steps of model1
  step(optim1, model1);
  step(optim1, model1);

  // Do 2 steps of model 2 without saving the optimizer
  step(optim2, model2);
  step(optim2_2, model2);

  // Do 2 steps of model 3 while saving the optimizer
  step(optim3, model3);

  torch::test::TempFile optim_tempfile;
  torch::save(optim3, optim_tempfile.str());
  torch::load(optim3_2, optim_tempfile.str());
  step(optim3_2, model3);

  param1 = model1->parameters();
  param2 = model2->parameters();
  param3 = model3->parameters();
  for (const auto& p : param1) {
    const auto& name = p.key;
    // Model 1 and 3 should be the same
    CATCH_REQUIRE(
        param1[name].norm().toCFloat() == param3[name].norm().toCFloat());
    CATCH_REQUIRE(
        param1[name].norm().toCFloat() != param2[name].norm().toCFloat());
  }
}

CATCH_TEST_CASE("Serialize/Default/CUDA", "[cuda]") {
  torch::manual_seed(0);
  // We better be able to save and load a XOR model!
  auto getLoss = [](Sequential model, uint32_t batch_size) {
    auto inputs = torch::empty({batch_size, 2});
    auto labels = torch::empty({batch_size});
    for (size_t i = 0; i < batch_size; i++) {
      inputs[i] = torch::randint(2, {2}, torch::kInt64);
      labels[i] = inputs[i][0].toCLong() ^ inputs[i][1].toCLong();
    }
    auto x = model->forward<torch::Tensor>(inputs);
    return torch::binary_cross_entropy(x, labels);
  };

  auto model = xor_model();
  auto model2 = xor_model();
  auto model3 = xor_model();
  auto optimizer = torch::optim::SGD(
      model->parameters(),
      torch::optim::SGDOptions(1e-1).momentum(0.9).nesterov(true).weight_decay(
          1e-6));

  float running_loss = 1;
  int epoch = 0;
  while (running_loss > 0.1) {
    torch::Tensor loss = getLoss(model, 4);
    optimizer.zero_grad();
    loss.backward();
    optimizer.step();

    running_loss = running_loss * 0.99 + loss.sum().toCFloat() * 0.01;
    CATCH_REQUIRE(epoch < 3000);
    epoch++;
  }

  torch::test::TempFile tempfile;
  torch::save(model, tempfile.str());
  torch::load(model2, tempfile.str());

  auto loss = getLoss(model2, 100);
  CATCH_REQUIRE(loss.toCFloat() < 0.1);

  model2->to(torch::kCUDA);
  torch::test::TempFile tempfile2;
  torch::save(model2, tempfile2.str());
  torch::load(model3, tempfile2.str());

  loss = getLoss(model3, 100);
  CATCH_REQUIRE(loss.toCFloat() < 0.1);
}
@@ -1,10 +1,17 @@
#pragma once

#include <torch/nn/cloneable.h>
#include <torch/tensor.h>

#include <cstdio>
#include <cstdlib>
#include <string>
#include <utility>

#ifndef WIN32
#include <unistd.h>
#endif

namespace torch {
namespace test {

@@ -25,5 +32,36 @@ class SimpleContainer : public nn::Cloneable<SimpleContainer> {
inline bool pointer_equal(at::Tensor first, at::Tensor second) {
  return first.data<float>() == second.data<float>();
}

#ifdef WIN32
struct TempFile {
  TempFile() : filename_(std::tmpnam(nullptr)) {}
  const std::string& str() const {
    return filename_;
  }
  std::string filename_;
};
#else
struct TempFile {
  TempFile() {
    // http://pubs.opengroup.org/onlinepubs/009695399/functions/mkstemp.html
    char filename[] = "/tmp/fileXXXXXX";
    fd_ = mkstemp(filename);
    AT_CHECK(fd_ != -1, "Error creating tempfile");
    filename_.assign(filename);
  }

  ~TempFile() {
    close(fd_);
  }

  const std::string& str() const {
    return filename_;
  }

  std::string filename_;
  int fd_;
};
#endif
} // namespace test
} // namespace torch
@@ -9,7 +9,6 @@ from setup_helpers.cuda import USE_CUDA
if __name__ == '__main__':
    # Placeholder for future interface. For now just gives a nice -h.
    parser = argparse.ArgumentParser(description='Build libtorch')
    parser.add_argument('--use-cereal', action='store_true')
    options = parser.parse_args()

    os.environ['BUILD_TORCH'] = 'ON'

@@ -25,8 +24,6 @@ if __name__ == '__main__':
        command.append('--use-cuda')
        if os.environ.get('USE_CUDA_STATIC_LINK', False):
            command.append('--cuda-static-link')
    if options.use_cereal:
        command.append('--use-cereal')
    command.append('caffe2')

    sys.stdout.flush()
@@ -22,7 +22,6 @@ USE_NNPACK=0
USE_MKLDNN=0
USE_GLOO_IBVERBS=0
CAFFE2_STATIC_LINK_CUDA=0
TORCH_USE_CEREAL=0
RERUN_CMAKE=1
while [[ $# -gt 0 ]]; do
  case "$1" in

@@ -47,9 +46,6 @@ while [[ $# -gt 0 ]]; do
    --cuda-static-link)
      CAFFE2_STATIC_LINK_CUDA=1
      ;;
    --use-cereal)
      TORCH_USE_CEREAL=1
      ;;
    *)
      break
      ;;

@@ -194,7 +190,6 @@ function build() {
      -DTHCUNN_SO_VERSION=1 \
      -DTHD_SO_VERSION=1 \
      -DUSE_CUDA=$USE_CUDA \
      -DTORCH_USE_CEREAL=$TORCH_USE_CEREAL \
      -DBUILD_EXAMPLES=OFF \
      -DBUILD_TEST=$BUILD_TEST \
      -DNO_NNPACK=$((1-$USE_NNPACK)) \
@@ -220,8 +220,8 @@ CONFIGURE_FILE(

if (NOT NO_API AND NOT USE_ROCM)
  list(APPEND TORCH_SRCS
    ${TORCH_SRC_DIR}/csrc/api/src/utils.cpp
    ${TORCH_SRC_DIR}/csrc/api/src/cuda.cpp
    ${TORCH_SRC_DIR}/csrc/api/src/jit.cpp
    ${TORCH_SRC_DIR}/csrc/api/src/nn/cursor.cpp
    ${TORCH_SRC_DIR}/csrc/api/src/nn/init.cpp
    ${TORCH_SRC_DIR}/csrc/api/src/nn/module.cpp

@@ -232,13 +232,17 @@ if (NOT NO_API AND NOT USE_ROCM)
    ${TORCH_SRC_DIR}/csrc/api/src/nn/modules/functional.cpp
    ${TORCH_SRC_DIR}/csrc/api/src/nn/modules/linear.cpp
    ${TORCH_SRC_DIR}/csrc/api/src/nn/modules/rnn.cpp
    ${TORCH_SRC_DIR}/csrc/api/src/optim/optimizer.cpp
    ${TORCH_SRC_DIR}/csrc/api/src/optim/adam.cpp
    ${TORCH_SRC_DIR}/csrc/api/src/optim/adagrad.cpp
    ${TORCH_SRC_DIR}/csrc/api/src/optim/adam.cpp
    ${TORCH_SRC_DIR}/csrc/api/src/optim/lbfgs.cpp
    ${TORCH_SRC_DIR}/csrc/api/src/optim/optimizer.cpp
    ${TORCH_SRC_DIR}/csrc/api/src/optim/rmsprop.cpp
    ${TORCH_SRC_DIR}/csrc/api/src/optim/serialize.cpp
    ${TORCH_SRC_DIR}/csrc/api/src/optim/sgd.cpp
    ${TORCH_SRC_DIR}/csrc/api/src/jit.cpp
    ${TORCH_SRC_DIR}/csrc/api/src/serialize.cpp
    ${TORCH_SRC_DIR}/csrc/api/src/serialize/input-archive.cpp
    ${TORCH_SRC_DIR}/csrc/api/src/serialize/output-archive.cpp
    ${TORCH_SRC_DIR}/csrc/api/src/utils.cpp
  )
endif()

@@ -330,13 +334,6 @@ if (NOT NO_API AND NOT USE_ROCM)
  target_include_directories(torch PUBLIC
    ${TORCH_SRC_DIR}/csrc/api
    ${TORCH_SRC_DIR}/csrc/api/include)

  if (TORCH_USE_CEREAL)
    target_compile_definitions(torch PUBLIC TORCH_USE_CEREAL)
    # SYSTEM headers are included with -isystem and thus do not trigger warnings.
    target_include_directories(torch SYSTEM PUBLIC
      "${TORCH_ROOT}/third_party/cereal/include") # For cereal/
  endif()
endif()

if(USE_CUDA)

@@ -445,6 +442,7 @@ if (BUILD_TEST AND NOT NO_API AND NOT USE_ROCM)
    ${TORCH_API_TEST_DIR}/any.cpp
    ${TORCH_API_TEST_DIR}/cursor.cpp
    ${TORCH_API_TEST_DIR}/integration.cpp
    ${TORCH_API_TEST_DIR}/jit.cpp
    ${TORCH_API_TEST_DIR}/main.cpp
    ${TORCH_API_TEST_DIR}/misc.cpp
    ${TORCH_API_TEST_DIR}/module.cpp

@@ -453,17 +451,13 @@ if (BUILD_TEST AND NOT NO_API AND NOT USE_ROCM)
    ${TORCH_API_TEST_DIR}/parallel.cpp
    ${TORCH_API_TEST_DIR}/rnn.cpp
    ${TORCH_API_TEST_DIR}/sequential.cpp
    ${TORCH_API_TEST_DIR}/serialize.cpp
    ${TORCH_API_TEST_DIR}/tensor_cuda.cpp
    ${TORCH_API_TEST_DIR}/tensor.cpp
    ${TORCH_API_TEST_DIR}/jit.cpp
    ${TORCH_API_TEST_DIR}/tensor_options.cpp
    ${TORCH_API_TEST_DIR}/tensor_options_cuda.cpp
    ${TORCH_API_TEST_DIR}/tensor_options.cpp
    ${TORCH_API_TEST_DIR}/tensor.cpp
  )

  if (TORCH_USE_CEREAL)
    list(APPEND TORCH_API_TEST_SOURCES ${TORCH_API_TEST_DIR}/serialization.cpp)
  endif()

  add_executable(test_api ${TORCH_API_TEST_SOURCES})

  target_include_directories(test_api
@@ -1,5 +1,8 @@
#pragma once

#include <torch/csrc/utils/variadic.h>
#include <torch/tensor.h>

#include <cstdint>
#include <type_traits>

@@ -52,7 +55,8 @@ inline constexpr bool check_not_lvalue_references<void>()

/// A type trait whose `value` member is true if `M` derives from `Module`.
template <typename M>
using is_module = std::is_base_of<torch::nn::Module, typename std::decay<M>::type>;
using is_module =
    std::is_base_of<torch::nn::Module, typename std::decay<M>::type>;

template <typename M, typename T = void>
using enable_if_module_t =
@ -3,6 +3,7 @@
|
|||
#include <torch/detail/ordered_dict.h>
|
||||
#include <torch/nn/cursor.h>
|
||||
#include <torch/nn/pimpl.h>
|
||||
#include <torch/serialize/archive.h>
|
||||
#include <torch/tensor.h>
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
|
|
@ -208,14 +209,6 @@ class Module {
|
|||
/// Recursively zeros out the `grad` value of each registered parameter.
|
||||
virtual void zero_grad();
|
||||
|
||||
/// Serializes the `Module`.
|
||||
template <class Archive>
|
||||
void save(Archive& ar) const;
|
||||
|
||||
/// Deserializes the `Module`.
|
||||
template <class Archive>
|
||||
void load(Archive& ar);
|
||||
|
||||
/// Attempts to cast this `Module` to the given `ModuleType`.
|
||||
///
|
||||
/// This method is useful when calling `apply()` on a `ModuleCursor`.
|
||||
|
|
@ -255,6 +248,12 @@ class Module {
|
|||
typename = torch::detail::disable_if_module_holder_t<ModuleType>>
|
||||
ModuleType* as() noexcept;
|
||||
|
||||
/// Serializes the `Module` into the given `OutputArchive`.
|
||||
virtual void save(serialize::OutputArchive& archive) const;
|
||||
|
||||
/// Deserializes the `Module` from the given `InputArchive`.
|
||||
virtual void load(serialize::InputArchive& archive);
|
||||
|
||||
protected:
|
||||
/// Registers a parameter with this `Module`.
|
||||
///
|
||||
|
|
@ -359,28 +358,6 @@ class Module {
|
|||
|
||||
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ nn::Module ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
template <class Archive>
|
||||
void Module::save(Archive& ar) const {
|
||||
auto params = parameters();
|
||||
size_t size = params.size();
|
||||
ar(size);
|
||||
for (auto& p : params) {
|
||||
ar(p.key, p.value);
|
||||
}
|
||||
}
|
||||
|
||||
template <class Archive>
|
||||
void Module::load(Archive& ar) {
|
||||
auto params = parameters();
|
||||
size_t size;
|
||||
ar(size);
|
||||
std::string name;
|
||||
for (size_t i = 0; i < size; i++) {
|
||||
ar(name);
|
||||
ar(params[name]);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename ModuleType>
|
||||
typename ModuleType::ContainedType* Module::as() noexcept {
|
||||
// Use the contained type of the `ModuleHolder`, e.g. `LinearImpl` for
|
||||
|
|
|
|||
|
|
@@ -1,15 +1,20 @@
#pragma once

#include <torch/nn/module.h>
#include <torch/nn/pimpl.h>
#include <torch/optim/optimizer.h>
#include <torch/serialization.h>
#include <torch/optim/serialize.h>
#include <torch/tensor.h>

#include <ATen/ATen.h>

#include <utility>
#include <vector>

namespace torch {
namespace serialize {
class OutputArchive;
class InputArchive;
} // namespace serialize
} // namespace torch

namespace torch {
namespace optim {

@@ -33,29 +38,20 @@ class Adagrad : public Optimizer {

  AdagradOptions options;

  template <class Archive>
  void serialize(Archive& ar) {
#if defined(TORCH_USE_CEREAL)
    ar(CEREAL_NVP(sum_));
    ar(CEREAL_NVP(step_));
#endif // defined(TORCH_USE_CEREAL)
  }
  void save(serialize::OutputArchive& archive) const override;
  void load(serialize::InputArchive& archive) override;

  std::vector<Tensor> sum_buffers;
  std::vector<int64_t> step_buffers;

 private:
#if defined(TORCH_USE_CEREAL)
  friend class cereal::access;
#endif // defined(TORCH_USE_CEREAL)
  Adagrad() : options(0) {}

  std::vector<Tensor> sum_;
  std::vector<int64_t> step_;
  template <typename Self, typename Archive>
  static void serialize(Self& self, Archive& archive) {
    TORCH_OPTIM_SERIALIZE(sum_buffers);
    TORCH_OPTIM_SERIALIZE(step_buffers);
  }
};
} // namespace optim
} // namespace torch

#if defined(TORCH_USE_CEREAL)
CEREAL_REGISTER_TYPE(torch::optim::Adagrad);
CEREAL_REGISTER_POLYMORPHIC_RELATION(
    torch::optim::Optimizer,
    torch::optim::Adagrad);
#endif // defined(TORCH_USE_CEREAL)
@@ -3,13 +3,18 @@
#include <torch/arg.h>
#include <torch/nn/module.h>
#include <torch/optim/optimizer.h>
#include <torch/serialization.h>

#include <ATen/ATen.h>
#include <torch/optim/serialize.h>

#include <utility>
#include <vector>

namespace torch {
namespace serialize {
class OutputArchive;
class InputArchive;
} // namespace serialize
} // namespace torch

namespace torch {
namespace optim {

@@ -32,35 +37,26 @@ class Adam : public Optimizer {

  void step() override;

  template <class Archive>
  void serialize(Archive& ar) {
#if defined(TORCH_USE_CEREAL)
    ar(CEREAL_NVP(step_buffers_),
       CEREAL_NVP(exp_average_buffers_),
       CEREAL_NVP(exp_average_sq_buffers_),
       CEREAL_NVP(max_exp_average_sq_buffers_));
#endif // defined(TORCH_USE_CEREAL)
  }
  void save(serialize::OutputArchive& archive) const override;
  void load(serialize::InputArchive& archive) override;

  AdamOptions options;

  std::vector<int64_t> step_buffers;
  std::vector<Tensor> exp_average_buffers;
  std::vector<Tensor> exp_average_sq_buffers;
  std::vector<Tensor> max_exp_average_sq_buffers;

 private:
#if defined(TORCH_USE_CEREAL)
  friend class cereal::access;
#endif // defined(TORCH_USE_CEREAL)
  Adam() : options(0) {}

  std::vector<int64_t> step_buffers_;
  std::vector<Tensor> exp_average_buffers_;
  std::vector<Tensor> exp_average_sq_buffers_;
  std::vector<Tensor> max_exp_average_sq_buffers_;
  template <typename Self, typename Archive>
  static void serialize(Self& self, Archive& archive) {
    TORCH_OPTIM_SERIALIZE(step_buffers);
    TORCH_OPTIM_SERIALIZE(exp_average_buffers);
    TORCH_OPTIM_SERIALIZE(exp_average_sq_buffers);
    TORCH_OPTIM_SERIALIZE(max_exp_average_sq_buffers);
  }
};
} // namespace optim
} // namespace torch

#if defined(TORCH_USE_CEREAL)
CEREAL_REGISTER_TYPE(torch::optim::Adam);
CEREAL_REGISTER_POLYMORPHIC_RELATION(
    torch::optim::Optimizer,
    torch::optim::Adam);
#endif // defined(TORCH_USE_CEREAL)
@@ -3,9 +3,8 @@
#include <torch/arg.h>
#include <torch/nn/module.h>
#include <torch/optim/optimizer.h>
#include <torch/serialization.h>

#include <ATen/ATen.h>
#include <torch/optim/serialize.h>
#include <torch/serialize/archive.h>

#include <deque>
#include <functional>

@@ -38,39 +37,37 @@ class LBFGS : public LossClosureOptimizer {

  LBFGSOptions options;

  template <class Archive>
  void serialize(Archive& ar) {
#if defined(TORCH_USE_CEREAL)
    ar(CEREAL_NVP(d));
    ar(CEREAL_NVP(t));
    ar(CEREAL_NVP(H_diag));
    ar(CEREAL_NVP(prev_flat_grad));
    ar(CEREAL_NVP(prev_loss));
    ar(CEREAL_NVP(old_dirs));
    ar(CEREAL_NVP(old_stps));
#endif // defined(TORCH_USE_CEREAL)
  }

 private:
#if defined(TORCH_USE_CEREAL)
  friend class cereal::access;
#endif // defined(TORCH_USE_CEREAL)
  LBFGS() : options(0) {}

  Tensor gather_flat_grad();
  void add_grad(const torch::Scalar& step_size, const Tensor& update);
  void save(serialize::OutputArchive& archive) const override;
  void load(serialize::InputArchive& archive) override;

  Tensor d{torch::empty({0})};
  Tensor H_diag{torch::empty({0})};
  Tensor prev_flat_grad{torch::empty({0})};
  torch::Scalar t{0};
  torch::Scalar prev_loss{0};
  Tensor t{torch::zeros(1)};
  Tensor prev_loss{torch::zeros(1)};
  std::vector<Tensor> ro;
  std::vector<Tensor> al;
  std::deque<Tensor> old_dirs;
  std::deque<Tensor> old_stps;
  int64_t func_evals{0};
  int64_t state_n_iter{0};

 private:
  LBFGS() : options(0) {}

  Tensor gather_flat_grad();
  void add_grad(const torch::Tensor& step_size, const Tensor& update);

  template <typename Self, typename Archive>
  static void serialize(Self& self, Archive& archive) {
    archive("d", self.d, /*is_buffer=*/true);
    archive("t", self.t, /*is_buffer=*/true);
    archive("H_diag", self.H_diag, /*is_buffer=*/true);
    archive("prev_flat_grad", self.prev_flat_grad, /*is_buffer=*/true);
    archive("prev_loss", self.prev_loss, /*is_buffer=*/true);
    detail::serialize(archive, "old_dirs", self.old_dirs);
    detail::serialize(archive, "old_stps", self.old_stps);
  }
};
} // namespace optim
} // namespace torch
@@ -1,18 +1,22 @@
#pragma once

#include <torch/csrc/autograd/generated/variable_factories.h>
#include <torch/csrc/utils/variadic.h>
#include <torch/nn/cursor.h>
#include <torch/optim/serialize.h>
#include <torch/serialize/archive.h>
#include <torch/tensor.h>

#include <algorithm>
#include <functional>
#include <iterator>
#include <memory>
#include <string>
#include <vector>

namespace torch {
namespace optim {
namespace detail {

/// Base class for all optimizers, that does not yet define a `step()`
/// mechanism. All it specifies is that optimizers must be supplied with a
/// vector of parameters. It also defines certain methods that all optimizers

@@ -48,6 +52,9 @@ class OptimizerBase {
  /// Returns the number of parameters referenced by the optimizer.
  size_t size() const noexcept;

  virtual void save(serialize::OutputArchive& archive) const;
  virtual void load(serialize::InputArchive& archive);

 protected:
  OptimizerBase() = default;
@@ -3,15 +3,21 @@
#include <torch/arg.h>
#include <torch/nn/module.h>
#include <torch/optim/optimizer.h>
#include <torch/serialization.h>

#include <ATen/ATen.h>
#include <torch/optim/serialize.h>
#include <torch/tensor.h>

#include <functional>
#include <memory>
#include <string>
#include <vector>

namespace torch {
namespace serialize {
class OutputArchive;
class InputArchive;
} // namespace serialize
} // namespace torch

namespace torch {
namespace optim {

@@ -38,31 +44,22 @@ class RMSprop : public Optimizer {

  RMSpropOptions options;

  template <class Archive>
  void serialize(Archive& ar) {
#if defined(TORCH_USE_CEREAL)
    ar(CEREAL_NVP(square_average_buffers_));
    ar(CEREAL_NVP(momentum_buffers_));
    ar(CEREAL_NVP(grad_average_buffers_));
#endif // defined(TORCH_USE_CEREAL)
  }
  void save(serialize::OutputArchive& archive) const override;
  void load(serialize::InputArchive& archive) override;

  std::vector<Tensor> square_average_buffers;
  std::vector<Tensor> momentum_buffers;
  std::vector<Tensor> grad_average_buffers;

 private:
#if defined(TORCH_USE_CEREAL)
  friend class cereal::access;
#endif // defined(TORCH_USE_CEREAL)
  RMSprop() : options(0) {}

  std::vector<Tensor> square_average_buffers_;
  std::vector<Tensor> momentum_buffers_;
  std::vector<Tensor> grad_average_buffers_;
  template <typename Self, typename Archive>
  static void serialize(Self& self, Archive& archive) {
    TORCH_OPTIM_SERIALIZE(square_average_buffers);
    TORCH_OPTIM_SERIALIZE(momentum_buffers);
    TORCH_OPTIM_SERIALIZE(grad_average_buffers);
  }
};
} // namespace optim
} // namespace torch

#if defined(TORCH_USE_CEREAL)
CEREAL_REGISTER_TYPE(torch::optim::RMSprop);
CEREAL_REGISTER_POLYMORPHIC_RELATION(
    torch::optim::Optimizer,
    torch::optim::RMSprop);
#endif // defined(TORCH_USE_CEREAL)
torch/csrc/api/include/torch/optim/serialize.h (new file, 67 lines)
@@ -0,0 +1,67 @@
#pragma once

#include <torch/serialize/archive.h>
#include <torch/tensor.h>

#include <cstddef>
#include <cstdint>
#include <deque>
#include <string>
#include <vector>

namespace torch {
namespace optim {
namespace detail {

// Note: These functions are all called `serialize()` so they can be called
// inside a template where the archive type is a template type and can thus be
// passed such that the appropriate overload is selected.

/// Utility function to save a vector of step buffers.
void serialize(
    serialize::OutputArchive& archive,
    const std::string& key,
    const std::vector<int64_t>& steps);

/// Utility function to load a vector of step buffers.
void serialize(
    serialize::InputArchive& archive,
    const std::string& key,
    std::vector<int64_t>& steps);

/// Utility function to save a vector of buffers.
template <typename BufferContainer>
void serialize(
    serialize::OutputArchive& archive,
    const std::string& key,
    const BufferContainer& buffers) {
  archive.write(
      key + "/size", torch::tensor(static_cast<int64_t>(buffers.size())));
  for (size_t index = 0; index < buffers.size(); ++index) {
    archive.write(
        key + "/" + std::to_string(index), buffers[index], /*is_buffer=*/true);
  }
}

/// Utility function to load a vector of buffers.
template <typename BufferContainer>
void serialize(
    serialize::InputArchive& archive,
    const std::string& key,
    BufferContainer& buffers) {
  torch::Tensor size_tensor;
  archive.read(key + "/size", size_tensor);
  const size_t size = size_tensor.toCLong();
  for (size_t index = 0; index < size; ++index) {
    buffers.emplace_back();
    archive.read(
        key + "/" + std::to_string(index), buffers.back(), /*is_buffer=*/true);
  }
}

#define TORCH_OPTIM_SERIALIZE(name) \
  torch::optim::detail::serialize(archive, #name, self.name)

} // namespace detail
} // namespace optim
} // namespace torch
@@ -3,15 +3,19 @@
#include <torch/arg.h>
#include <torch/nn/module.h>
#include <torch/optim/optimizer.h>
#include <torch/serialization.h>
#include <torch/tensor.h>

#include <ATen/ATen.h>

#include <cstddef>
#include <utility>
#include <vector>

namespace torch {
namespace serialize {
class OutputArchive;
class InputArchive;
} // namespace serialize
} // namespace torch

namespace torch {
namespace optim {

@@ -33,31 +37,18 @@ class SGD : public Optimizer {

  void step() override;

  template <class Archive>
  void serialize(Archive& ar) {
#if defined(TORCH_USE_CEREAL)
    ar(CEREAL_NVP(momentum_buffers_));
#endif // defined(TORCH_USE_CEREAL)
  }
  void save(serialize::OutputArchive& archive) const override;
  void load(serialize::InputArchive& archive) override;

  SGDOptions options;

  std::vector<Tensor> momentum_buffers;

 private:
#if defined(TORCH_USE_CEREAL)
  friend class cereal::access;
#endif // defined(TORCH_USE_CEREAL)
  SGD() : options(0) {}

  std::vector<Tensor> momentum_buffers_;
  /// Counts how often `step()` is called, for dampening.
  size_t iteration_{0};
};
} // namespace optim
} // namespace torch

#if defined(TORCH_USE_CEREAL)
CEREAL_REGISTER_TYPE(torch::optim::SGD);
CEREAL_REGISTER_POLYMORPHIC_RELATION(
    torch::optim::Optimizer,
    torch::optim::SGD);
#endif // defined(TORCH_USE_CEREAL)
@@ -1,276 +0,0 @@
#pragma once

#include <fstream>

#include <torch/tensor.h>
#include <torch/utils.h>

#if defined(TORCH_USE_CEREAL)
#include <cereal/access.hpp>
#include <cereal/cereal.hpp>
#include <cereal/types/polymorphic.hpp>

#include "cereal/archives/binary.hpp"

#include "cereal/types/string.hpp"
#include "cereal/types/unordered_map.hpp"
#include "cereal/types/vector.hpp"
#endif // defined(TORCH_USE_CEREAL)

namespace torch {
// Some convenience functions for saving and loading
template <typename T>
void save(std::ostream& stream, T const& obj) {
#if defined(TORCH_USE_CEREAL)
  cereal::BinaryOutputArchive archive(stream);
  archive(*obj);
#else
  AT_ERROR("PyTorch compiled without serialization support");
#endif
}

template <typename T>
void load(std::istream& stream, T& obj) {
#if defined(TORCH_USE_CEREAL)
  cereal::BinaryInputArchive archive(stream);
  archive(*obj);
#else
  AT_ERROR("PyTorch compiled without serialization support");
#endif
}

template <typename T>
void save(std::ostream& stream, T const* obj) {
#if defined(TORCH_USE_CEREAL)
  cereal::BinaryOutputArchive archive(stream);
  archive(*obj);
#else
  AT_ERROR("PyTorch compiled without serialization support");
#endif
}

template <typename T>
void load(std::istream& stream, T* obj) {
#if defined(TORCH_USE_CEREAL)
  cereal::BinaryInputArchive archive(stream);
  archive(*obj);
#else
  AT_ERROR("PyTorch compiled without serialization support");
#endif
}

template <typename T>
void save(std::string const& path, T const& obj) {
  std::ofstream os(path, std::ios::binary);
  torch::save(os, obj);
}

template <typename T>
void load(std::string const& path, T& obj) {
  std::ifstream is(path, std::ios::binary);
  torch::load(is, obj);
}

namespace detail {

// We use our own hard-coded type<->id mapping so that serialization is robust
// wrt changes in ATen; see e.g. https://git.io/vxd6R
// The mapping is consistent with the ScalarType enum as of pytorch version
// v0.1.11-7675-ge94c67e.
inline int32_t scalarTypeId(torch::Dtype type) {
  switch (type) {
    case torch::Dtype::Byte:
      return 0;
    case torch::Dtype::Char:
      return 1;
    case torch::Dtype::Short:
      return 2;
    case torch::Dtype::Int:
      return 3;
    case torch::Dtype::Long:
      return 4;
    case torch::Dtype::Half:
      return 5;
    case torch::Dtype::Float:
      return 6;
    case torch::Dtype::Double:
      return 7;
    case torch::Dtype::Undefined:
      return 8;
    default:
      AT_ERROR("Unknown scalar type: ", static_cast<int>(type));
  }
}

inline torch::Dtype scalarTypeFromId(int32_t id) {
  switch (id) {
    case 0:
      return torch::Dtype::Byte;
    case 1:
      return torch::Dtype::Char;
    case 2:
      return torch::Dtype::Short;
    case 3:
      return torch::Dtype::Int;
    case 4:
      return torch::Dtype::Long;
    case 5:
      return torch::Dtype::Half;
    case 6:
      return torch::Dtype::Float;
    case 7:
      return torch::Dtype::Double;
    case 8:
      return torch::Dtype::Undefined;
    default:
      AT_ERROR("Unknown scalar type id: ", id);
  }
}

inline int32_t backendId(at::Backend backend) {
  switch (backend) {
    case at::Backend::CPU:
      return 0;
    case at::Backend::CUDA:
      return 1;
    case at::Backend::SparseCPU:
      return 2;
    case at::Backend::SparseCUDA:
      return 3;
    case at::Backend::Undefined:
      return 4;
    default:
      AT_ERROR("Unknown backend: ", static_cast<int>(backend));
  }
}

inline at::Backend backendFromId(int32_t id) {
  switch (id) {
    case 0:
      return at::Backend::CPU;
    case 1:
      return at::Backend::CUDA;
    case 2:
      return at::Backend::SparseCPU;
    case 3:
      return at::Backend::SparseCUDA;
    case 4:
      return at::Backend::Undefined;
    default:
      AT_ERROR("Unknown backend id: ", id);
  }
}

} // namespace detail
} // namespace torch

#if defined(TORCH_USE_CEREAL)
namespace cereal {
namespace agimpl {

template <class Archive>
void saveBinary(Archive& archive, void const* data, size_t size) {
  // In general, there's no direct `saveBinary`-like method on archives
  std::vector<char> v(
      static_cast<char const*>(data), static_cast<char const*>(data) + size);
  archive(v);
}
template <>
inline void saveBinary(
    BinaryOutputArchive& archive,
    void const* data,
    size_t size) {
  // Writes to output stream without extra copy
  archive.saveBinary(data, size);
}

template <class Archive>
void loadBinary(Archive& archive, void* data, size_t size) {
  // In general, there's no direct `loadBinary`-like method on archives
  std::vector<char> v(size);
  archive(v);
  std::memcpy(data, v.data(), size);
}
template <>
inline void loadBinary(BinaryInputArchive& archive, void* data, size_t size) {
  // Read from input stream without extra copy
  archive.loadBinary(data, size);
}

} // namespace agimpl

// Gradients will not be saved for variables
template <class Archive>
void save(Archive& archive, const torch::Tensor& tensor) {
  if (!tensor.defined()) {
    int32_t typeId = ::torch::detail::scalarTypeId(torch::Dtype::Undefined);
    archive(CEREAL_NVP(typeId));
    return;
  } else {
    int32_t typeId = ::torch::detail::scalarTypeId(tensor.dtype());
    archive(CEREAL_NVP(typeId));
  }
  auto sizes = std::vector<int64_t>();
  auto buf = std::vector<uint8_t>();
  for (auto s : tensor.sizes()) {
    sizes.push_back(s);
  }
  auto contig = tensor.cpu().contiguous();
  int32_t backend = ::torch::detail::backendId(tensor.type().backend());

  archive(CEREAL_NVP(backend), CEREAL_NVP(sizes));
  agimpl::saveBinary(
      archive,
      contig.data_ptr(),
      tensor.numel() * tensor.type().elementSizeInBytes());
}

/**
 * We follow these rules for loading:
 * 1. If tensor is defined, and the same ScalarType as the saved tensor,
 *    then we simply copy the data into the tensor, with resizing.
 * 2. Otherwise, overwrite the provided tensor with the right type and backend
 **/
template <class Archive>
void load(Archive& archive, torch::Tensor& tensor) {
  torch::NoGradGuard guard;
  torch::Dtype type;
  int32_t typeId;
  archive(CEREAL_NVP(typeId));
  type = ::torch::detail::scalarTypeFromId(typeId);
  if (type == torch::Dtype::Undefined) {
    tensor = torch::Tensor();
    return;
  }

  int32_t backendId;
  auto sizes = std::vector<int64_t>();
  auto buf = std::vector<uint8_t>();
  archive(CEREAL_NVP(backendId), CEREAL_NVP(sizes));

  at::Backend backend = ::torch::detail::backendFromId(backendId);
  if (!tensor.defined() || tensor.dtype() != type) {
    tensor = torch::empty({}, at::TensorOptions(backend).dtype(type));
  }
  const auto required_grad = tensor.requires_grad();
  tensor.set_requires_grad(false);
  tensor.resize_(sizes);
  tensor.set_requires_grad(required_grad);

  if (tensor.type().is_cuda()) {
    // should actually use cudamemcpy probably
    auto cputensor = torch::empty(sizes, tensor.dtype());
    agimpl::loadBinary(
        archive,
        cputensor.data_ptr(),
        cputensor.numel() * cputensor.type().elementSizeInBytes());
    tensor.copy_(cputensor);
  } else {
    agimpl::loadBinary(
        archive,
        tensor.data_ptr(),
        tensor.numel() * tensor.type().elementSizeInBytes());
  }
}
} // namespace cereal
#endif // defined(TORCH_USE_CEREAL)
torch/csrc/api/include/torch/serialize.h (new file, 49 lines)
@@ -0,0 +1,49 @@
#pragma once

#include <torch/nn/module.h>
#include <torch/nn/pimpl.h>
#include <torch/serialize/archive.h>
#include <torch/tensor.h>

#include <string>
#include <utility>

namespace torch {
namespace optim {
class Optimizer;
} // namespace optim
} // namespace torch

namespace torch {
namespace serialize {
template <typename ModuleType>
void save(const nn::ModuleHolder<ModuleType>& module, OutputArchive& archive) {
  module->save(archive);
}

template <typename ModuleType>
void load(nn::ModuleHolder<ModuleType>& module, InputArchive& archive) {
  module->load(archive);
}

void save(const Tensor& tensor, OutputArchive& archive);

void save(const optim::Optimizer& optimizer, OutputArchive& archive);
void load(optim::Optimizer& optimizer, InputArchive& archive);
} // namespace serialize

template <typename T>
void save(const T& value, const std::string& filename) {
  serialize::OutputArchive archive;
  serialize::save(value, archive);
  serialize::save_to_file(archive, filename);
}

template <typename T>
void load(T& value, const std::string& filename) {
  serialize::InputArchive archive = serialize::load_from_file(filename);
  serialize::load(value, archive);
}

Tensor load(const std::string& filename);
} // namespace torch
4
torch/csrc/api/include/torch/serialize/archive.h
Normal file
4
torch/csrc/api/include/torch/serialize/archive.h
Normal file
|
|
@ -0,0 +1,4 @@
|
|||
#pragma once
|
||||
|
||||
#include <torch/serialize/input-archive.h>
|
||||
#include <torch/serialize/output-archive.h>
|
||||
41
torch/csrc/api/include/torch/serialize/input-archive.h
Normal file
41
torch/csrc/api/include/torch/serialize/input-archive.h
Normal file
|
|
@ -0,0 +1,41 @@
|
|||
#pragma once

#include <torch/tensor.h>

#include <memory>
#include <string>
#include <utility>

namespace torch {
namespace jit {
namespace script {
struct Module;
} // namespace script
} // namespace jit
} // namespace torch

namespace torch {
namespace serialize {
class InputArchive final {
 public:
  InputArchive();

  void read(const std::string& key, Tensor& tensor, bool is_buffer = false);
  void read(const std::string& key, InputArchive& archive);

  template <typename... Ts>
  void operator()(Ts&&... ts) {
    read(std::forward<Ts>(ts)...);
  }

 private:
  friend InputArchive load_from_file(const std::string& filename);

  InputArchive(std::shared_ptr<jit::script::Module> module);

  std::shared_ptr<jit::script::Module> module_;
};

InputArchive load_from_file(const std::string& filename);
} // namespace serialize
} // namespace torch
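The variadic `operator()` makes an archive usable as a callable sink, which is how the optimizer `serialize` helpers can treat both archive types uniformly. A reading sketch (key and filename are illustrative):

```cpp
torch::serialize::InputArchive archive =
    torch::serialize::load_from_file("checkpoint.pt");
torch::Tensor weight;
archive("weight", weight); // forwards to read("weight", weight)
```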
42
torch/csrc/api/include/torch/serialize/output-archive.h
Normal file
@ -0,0 +1,42 @@
#pragma once

#include <torch/tensor.h>

#include <memory>
#include <string>
#include <utility>

namespace torch {
namespace jit {
namespace script {
struct Module;
} // namespace script
} // namespace jit
} // namespace torch

namespace torch {
namespace serialize {
class OutputArchive final {
 public:
  OutputArchive();

  void write(
      const std::string& key,
      const Tensor& tensor,
      bool is_buffer = false);
  void write(const std::string& key, OutputArchive& nested_archive);

  template <typename... Ts>
  void operator()(Ts&&... ts) {
    write(std::forward<Ts>(ts)...);
  }

 private:
  friend void save_to_file(const OutputArchive&, const std::string&);

  std::shared_ptr<jit::script::Module> module_;
};

void save_to_file(const OutputArchive& archive, const std::string& filename);
} // namespace serialize
} // namespace torch
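The write side mirrors the reader: tensors become parameters of the wrapped `jit::script::Module`, and nested archives become its submodules. A writing sketch (key and filename are illustrative):

```cpp
torch::serialize::OutputArchive archive;
archive.write("weight", torch::randn({4, 3}));
torch::serialize::save_to_file(archive, "checkpoint.pt");
```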
@ -3,6 +3,6 @@
#include <torch/cuda.h>
#include <torch/nn.h>
#include <torch/optim.h>
#include <torch/serialization.h>
#include <torch/serialize.h>
#include <torch/tensor.h>
#include <torch/utils.h>
@ -119,6 +119,39 @@ void Module::zero_grad() {
  }
}

void Module::save(serialize::OutputArchive& archive) const {
  for (const auto& parameter : parameters_) {
    archive.write(parameter.key, parameter.value);
  }
  for (const auto& buffer : buffers_) {
    archive.write(buffer.key, buffer.value, /*is_buffer=*/true);
  }
  for (const auto& child : children_) {
    serialize::OutputArchive child_archive;
    child.value->save(child_archive);
    archive.write(child.key, child_archive);
  }
}

void Module::load(serialize::InputArchive& archive) {
  for (auto& parameter : parameters_) {
    archive.read(parameter.key, parameter.value);
  }
  for (auto& buffer : buffers_) {
    archive.read(buffer.key, buffer.value, /*is_buffer=*/true);
  }
  for (const auto& child : children_) {
    // Modules with no state at all (no parameters or buffers) are currently
    // not stored in the Protobuf output, so we can simply skip them.
    if (!child.value->parameters_.is_empty() ||
        !child.value->buffers_.is_empty()) {
      serialize::InputArchive child_archive;
      archive.read(child.key, child_archive);
      child.value->load(child_archive);
    }
  }
}

Tensor& Module::register_parameter(
    std::string name,
    Tensor tensor,
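Since each child gets its own nested archive under its registered name, a module tree serializes recursively. A sketch of what that produces for a hypothetical two-layer net (names illustrative):

```cpp
// Hypothetical net: saving it nests archives as
// "fc1" -> {"weight", "bias"} and "fc2" -> {"weight", "bias"}.
struct NetImpl : torch::nn::Module {
  NetImpl() : fc1(8, 4), fc2(4, 2) {
    register_module("fc1", fc1);
    register_module("fc2", fc2);
  }
  torch::nn::Linear fc1, fc2;
};
TORCH_MODULE(Net);

// Net net;
// torch::save(net, "net.pt");
```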
@ -1,6 +1,7 @@
#include <torch/optim/adagrad.h>

#include <torch/csrc/autograd/variable.h>
#include <torch/serialize/archive.h>
#include <torch/utils.h>

#include <ATen/ATen.h>
@ -26,17 +27,25 @@ void Adagrad::step() {
      p.grad() = p.grad() + options.weight_decay_ * p;
    }

    buffer_at(step_, i) += 1.0;
    buffer_at(step_buffers, i) += 1.0;
    const auto clr = options.learning_rate_ /
        (1.0 + (buffer_at(step_, i) - 1.0) * options.lr_decay_);
        (1.0 + (buffer_at(step_buffers, i) - 1.0) * options.lr_decay_);

    auto& sum = buffer_at(sum_, i);
    auto& sum = buffer_at(sum_buffers, i);
    sum.addcmul_(p.grad(), p.grad(), 1.0);
    const auto std = buffer_at(sum_, i).sqrt().add_(1e-10);
    const auto std = buffer_at(sum_buffers, i).sqrt().add_(1e-10);

    NoGradGuard guard;
    p.addcdiv_(p.grad(), std, -clr);
  }
}

void Adagrad::save(serialize::OutputArchive& archive) const {
  serialize(*this, archive);
}

void Adagrad::load(serialize::InputArchive& archive) {
  serialize(*this, archive);
}
} // namespace optim
} // namespace torch
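Each concrete optimizer forwards `save`/`load` to a shared `serialize` helper over its buffers, so a checkpoint roundtrip is symmetric. A sketch, assuming the options type takes the learning rate as in the C++ API tests of this era (filename illustrative):

```cpp
torch::nn::Linear model(3, 4);
torch::optim::Adagrad optimizer(
    model->parameters(), torch::optim::AdagradOptions(0.01));
// ... optimizer.step() during training ...
torch::save(optimizer, "adagrad.pt"); // writes the step/sum buffers
torch::load(optimizer, "adagrad.pt"); // restores them
```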
@ -2,6 +2,7 @@

#include <torch/csrc/autograd/variable.h>
#include <torch/nn/module.h>
#include <torch/serialize/archive.h>
#include <torch/utils.h>

#include <ATen/ATen.h>
@ -11,7 +12,6 @@

namespace torch {
namespace optim {

AdamOptions::AdamOptions(double learning_rate)
    : learning_rate_(learning_rate) {}
@ -26,10 +26,10 @@ void Adam::step() {
      p.grad() = p.grad() + options.weight_decay_ * p;
    }

    auto& exp_average = buffer_at(exp_average_buffers_, i);
    auto& exp_average_sq = buffer_at(exp_average_sq_buffers_, i);
    auto& exp_average = buffer_at(exp_average_buffers, i);
    auto& exp_average_sq = buffer_at(exp_average_sq_buffers, i);

    buffer_at(step_buffers_, i) += 1;
    buffer_at(step_buffers, i) += 1;

    exp_average.mul_(options.beta1_).add_(p.grad(), 1 - options.beta1_);
    exp_average_sq.mul_(options.beta2_)

@ -37,15 +37,15 @@ void Adam::step() {

    Tensor denom = exp_average_sq;
    if (options.amsgrad_) {
      auto& max_exp_average_sq = buffer_at(max_exp_average_sq_buffers_, i);
      auto& max_exp_average_sq = buffer_at(max_exp_average_sq_buffers, i);
      max_exp_average_sq = torch::max(max_exp_average_sq, exp_average_sq);
      denom = max_exp_average_sq;
    }

    const auto bias_correction1 =
        1 - std::pow(options.beta1_, buffer_at(step_buffers_, i));
        1 - std::pow(options.beta1_, buffer_at(step_buffers, i));
    const auto bias_correction2 =
        1 - std::pow(options.beta2_, buffer_at(step_buffers_, i));
        1 - std::pow(options.beta2_, buffer_at(step_buffers, i));
    const auto step_size =
        options.learning_rate_ * std::sqrt(bias_correction2) / bias_correction1;

@ -54,5 +54,12 @@ void Adam::step() {
  }
}

void Adam::save(serialize::OutputArchive& archive) const {
  serialize(*this, archive);
}

void Adam::load(serialize::InputArchive& archive) {
  serialize(*this, archive);
}
} // namespace optim
} // namespace torch
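For reference, the bias-corrected update the Adam hunks above implement (with `exp_average` as m_t, `exp_average_sq` as v_t, and `step_size` folding both corrections into the learning rate) is the standard one:

```latex
m_t = \beta_1 m_{t-1} + (1-\beta_1)\,g_t, \qquad
v_t = \beta_2 v_{t-1} + (1-\beta_2)\,g_t^2, \qquad
\theta_{t+1} = \theta_t
  - \underbrace{\eta\,\frac{\sqrt{1-\beta_2^t}}{1-\beta_1^t}}_{\texttt{step\_size}}
    \cdot \frac{m_t}{\sqrt{v_t}+\epsilon}
```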
@ -2,6 +2,7 @@

#include <torch/csrc/autograd/generated/variable_factories.h>
#include <torch/csrc/autograd/variable.h>
#include <torch/serialize/archive.h>

#include <ATen/ATen.h>
@ -23,12 +24,14 @@ Tensor LBFGS::gather_flat_grad() {
  return at::cat(views);
}

void LBFGS::add_grad(const torch::Scalar& step_size, const Tensor& update) {
void LBFGS::add_grad(const torch::Tensor& step_size, const Tensor& update) {
  int64_t offset = 0;
  for (auto& parameter : parameters_) {
    int64_t numel = parameter.numel();
    Tensor& pd = autograd::Variable(parameter).data();
    pd.add_(update.slice(0, offset, offset + numel, 1).view_as(pd), step_size);
    pd.add_(
        update.slice(0, offset, offset + numel, 1).view_as(pd),
        step_size.toCFloat());
    offset += numel;
  }
}

@ -109,9 +112,9 @@ torch::Tensor LBFGS::step(LossClosure closure) {

    // reset initial guess for step size
    if (n_iter == 1) {
      t = at::_local_scalar(at::min(ONE, ONE / abs_grad_sum) * options.learning_rate_);
      t = torch::min(ONE, ONE / abs_grad_sum) * options.learning_rate_;
    } else {
      t = options.learning_rate_;
      t = at::tensor(options.learning_rate_, torch::kFloat32);
    }

    Tensor gtd = flat_grad.dot(d);

@ -141,17 +144,23 @@ torch::Tensor LBFGS::step(LossClosure closure) {
      break;
    } else if (gtd.toCFloat() > -options.tolerance_grad_) {
      break;
    } else if (
        d.mul(t).abs_().sum().toCFloat() <=
        options.tolerance_change_) {
    } else if (d.mul(t).abs_().sum().toCFloat() <= options.tolerance_change_) {
      break;
    } else if (
        std::abs(loss.toCFloat() - prev_loss.toFloat()) <
        std::abs(loss.toCFloat() - prev_loss.toCFloat()) <
        options.tolerance_change_) {
      break;
    }
  }
  return orig_loss;
}

void LBFGS::save(serialize::OutputArchive& archive) const {
  serialize(*this, archive);
}

void LBFGS::load(serialize::InputArchive& archive) {
  serialize(*this, archive);
}
} // namespace optim
} // namespace torch
@ -3,13 +3,13 @@
#include <torch/nn/cursor.h>
#include <torch/tensor.h>

#include <string>
#include <utility>
#include <vector>

namespace torch {
namespace optim {
namespace detail {

OptimizerBase::OptimizerBase(std::vector<Tensor> parameters)
    : parameters_(std::move(parameters)) {}
@ -47,6 +47,9 @@ std::vector<Tensor>& OptimizerBase::parameters() noexcept {
size_t OptimizerBase::size() const noexcept {
  return parameters_.size();
}

void OptimizerBase::save(serialize::OutputArchive& archive) const {}
void OptimizerBase::load(serialize::InputArchive& archive) {}
} // namespace detail
} // namespace optim
} // namespace torch
@ -1,6 +1,7 @@
#include <torch/optim/rmsprop.h>

#include <torch/csrc/autograd/variable.h>
#include <torch/serialize/archive.h>
#include <torch/utils.h>

#include <ATen/ATen.h>
@ -26,13 +27,13 @@ void RMSprop::step() {
      p.grad() = p.grad() + options.weight_decay_ * p;
    }

    auto square_average = buffer_at(square_average_buffers_, i);
    auto square_average = buffer_at(square_average_buffers, i);
    square_average.mul_(options.alpha_)
        .addcmul_(p.grad(), p.grad(), 1.0 - options.alpha_);

    Tensor average;
    if (options.centered_ > 0) {
      auto& grad_average = buffer_at(grad_average_buffers_, i);
      auto& grad_average = buffer_at(grad_average_buffers, i);
      grad_average.mul_(options.alpha_).add_(p.grad(), 1.0 - options.alpha_);
      average = square_average.addcmul(grad_average, grad_average, -1.0)
                    .sqrt()

@ -43,7 +44,7 @@ void RMSprop::step() {

    NoGradGuard guard;
    if (options.momentum_ > 0) {
      auto& momentum = buffer_at(momentum_buffers_, i);
      auto& momentum = buffer_at(momentum_buffers, i);
      momentum.mul_(options.momentum_).addcdiv_(p.grad(), average);
      p.add_(momentum, -options.learning_rate_);
    } else {

@ -51,5 +52,13 @@ void RMSprop::step() {
    }
  }
}

void RMSprop::save(serialize::OutputArchive& archive) const {
  serialize(*this, archive);
}

void RMSprop::load(serialize::InputArchive& archive) {
  serialize(*this, archive);
}
} // namespace optim
} // namespace torch
39
torch/csrc/api/src/optim/serialize.cpp
Normal file
@ -0,0 +1,39 @@
#include <torch/optim/serialize.h>

#include <torch/serialize/archive.h>
#include <torch/tensor.h>

#include <cstddef>
#include <cstdint>
#include <deque>
#include <string>
#include <vector>

namespace torch {
namespace optim {
namespace detail {
void serialize(
    serialize::OutputArchive& archive,
    const std::string& key,
    const std::vector<int64_t>& steps) {
  std::vector<torch::Tensor> tensors;
  for (const auto& step : steps) {
    tensors.push_back(torch::tensor(static_cast<int64_t>(step)));
  }
  serialize(archive, key, tensors);
}

void serialize(
    serialize::InputArchive& archive,
    const std::string& key,
    std::vector<int64_t>& steps) {
  std::vector<torch::Tensor> tensors;
  serialize(archive, key, tensors);
  steps.clear();
  for (const auto& step : tensors) {
    steps.push_back(step.toCLong());
  }
}
} // namespace detail
} // namespace optim
} // namespace torch
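The archive interface only speaks tensors, so scalar optimizer state takes a detour through scalar tensors, as the two overloads above show. The boxing in isolation:

```cpp
int64_t step = 42;
torch::Tensor boxed = torch::tensor(step); // box the integer into a tensor
int64_t unboxed = boxed.toCLong();         // back to a C++ integer
```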
@ -1,6 +1,10 @@
#include <torch/optim/sgd.h>

#include <torch/csrc/autograd/variable.h>
#include <torch/nn/pimpl.h>
#include <torch/optim/optimizer.h>
#include <torch/optim/serialize.h>
#include <torch/tensor.h>
#include <torch/utils.h>

#include <ATen/ATen.h>
@ -22,7 +26,7 @@ void SGD::step() {
    auto update = options.learning_rate_ * p.grad();
    if (options.momentum_ != 0) {
      const auto dampening = iteration_ == 0 ? 1 : 1 - options.dampening_;
      auto& momentum = buffer_at(momentum_buffers_, i);
      auto& momentum = buffer_at(momentum_buffers, i);
      momentum = (options.momentum_ * momentum) + (dampening * update);
      if (options.nesterov_) {
        // See github.com/lisa-lab/pylearn2/pull/136#issuecomment-10381617

@ -42,5 +46,13 @@ void SGD::step() {
  }
  iteration_ += 1;
}

void SGD::save(serialize::OutputArchive& archive) const {
  detail::serialize(archive, "momentum_buffers", momentum_buffers);
}

void SGD::load(serialize::InputArchive& archive) {
  detail::serialize(archive, "momentum_buffers", momentum_buffers);
}
} // namespace optim
} // namespace torch
31
torch/csrc/api/src/serialize.cpp
Normal file
@ -0,0 +1,31 @@
#include <torch/serialize.h>

#include <torch/optim/optimizer.h>
#include <torch/serialize/archive.h>
#include <torch/tensor.h>

#include <string>
#include <utility>

namespace torch {
namespace serialize {
void save(const Tensor& tensor, OutputArchive& archive) {
  archive.write("0", tensor);
}

void save(const optim::Optimizer& optimizer, OutputArchive& archive) {
  optimizer.save(archive);
}

void load(optim::Optimizer& optimizer, InputArchive& archive) {
  optimizer.load(archive);
}
} // namespace serialize

Tensor load(const std::string& filename) {
  serialize::InputArchive archive = serialize::load_from_file(filename);
  Tensor tensor;
  archive.read("0", tensor);
  return tensor;
}
} // namespace torch
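Because single tensors are stored under the fixed key "0", `torch::save` and `torch::load` form a matched pair for them. A roundtrip sketch (filename illustrative):

```cpp
torch::Tensor original = torch::randn({2, 3});
torch::save(original, "tensor.pt");                // stored under key "0"
torch::Tensor restored = torch::load("tensor.pt"); // read back from key "0"
```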
57
torch/csrc/api/src/serialize/input-archive.cpp
Normal file
@ -0,0 +1,57 @@
#include <torch/serialize/input-archive.h>

#include <torch/tensor.h>
#include <torch/utils.h>

#include <torch/csrc/jit/import.h>
#include <torch/csrc/jit/script/module.h>

#include <ATen/Error.h>

#include <memory>
#include <string>
#include <utility>

namespace torch {
namespace serialize {

InputArchive::InputArchive()
    : module_(std::make_shared<jit::script::Module>()) {}

void InputArchive::read(
    const std::string& key,
    Tensor& tensor,
    bool is_buffer) {
  auto* read_tensor = module_->find_parameter(key);
  AT_CHECK(read_tensor != nullptr, "No such serialized tensor '", key, "'");
  // clang-format off
  AT_CHECK(
      read_tensor->is_buffer == is_buffer,
      "Expected deserialized tensor for key '", key, "' to ",
      is_buffer ? "" : "not ", "be a buffer, but it was ",
      read_tensor->is_buffer ? "" : "not ", "one");
  // clang-format on
  if (tensor.defined()) {
    torch::NoGradGuard guard;
    tensor.set_(*read_tensor->slot());
  } else {
    tensor = std::move(*read_tensor->slot());
  }
}

void InputArchive::read(const std::string& key, InputArchive& archive) {
  if (auto* named_module = module_->find_module(key)) {
    AT_ASSERT(named_module->module != nullptr);
    archive.module_ = std::move(named_module->module);
  } else {
    AT_ERROR("No such serialized submodule: '", key, "'");
  }
}

InputArchive::InputArchive(std::shared_ptr<jit::script::Module> module)
    : module_(std::move(module)) {}

InputArchive load_from_file(const std::string& filename) {
  return InputArchive(torch::jit::load(filename));
}
} // namespace serialize
} // namespace torch
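Note the two branches at the end of `read`: a defined tensor is updated in place with `set_`, so existing references to it observe the loaded values, while an undefined tensor simply adopts the deserialized one. A sketch of the difference (key and filename illustrative):

```cpp
torch::serialize::InputArchive archive =
    torch::serialize::load_from_file("checkpoint.pt");

torch::Tensor fresh; // undefined: adopts the stored tensor wholesale
archive.read("weight", fresh);

torch::Tensor existing = torch::zeros({4, 3}); // defined: repointed via set_
archive.read("weight", existing);
```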
34
torch/csrc/api/src/serialize/output-archive.cpp
Normal file
@ -0,0 +1,34 @@
#include <torch/serialize/output-archive.h>

#include <torch/tensor.h>
#include <torch/utils.h>

#include <torch/csrc/jit/export.h>
#include <torch/csrc/jit/script/module.h>

#include <memory>
#include <string>

namespace torch {
namespace serialize {
OutputArchive::OutputArchive()
    : module_(std::make_shared<jit::script::Module>()) {}

void OutputArchive::write(
    const std::string& key,
    const Tensor& tensor,
    bool is_buffer) {
  module_->register_parameter(key, tensor, is_buffer);
}

void OutputArchive::write(
    const std::string& key,
    OutputArchive& nested_archive) {
  module_->register_module(key, nested_archive.module_);
}

void save_to_file(const OutputArchive& archive, const std::string& filename) {
  jit::ExportModule(*archive.module_, filename);
}
} // namespace serialize
} // namespace torch
@ -334,6 +334,7 @@ at::Tensor ModuleDecoder::buildTensorCommon(
std::pair<std::shared_ptr<script::Module>, std::string> ModuleDecoder::parseFullName(
    ModuleLookup module_lookup,
    const std::string fullname) {
  AT_ASSERT(!fullname.empty());
  std::vector<std::string> vec;
  std::stringstream ss(fullname);
  std::string name;