mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Fixes #95061, #95062 Add Flatbuffer verification before parsing to avoid crashing on malformed modules. Flatbuffers doesn't perform boundary checks at runtime for the sake of performance, so when parsing untrusted modules it is highly recommended to verify overall buffer integrity. This bug can be triggered by both the C++ API (`torch::jit::load`, `torch::jit::load_jit_module_from_file`) and the Python API (`torch.jit.load`, `torch.jit.jit_module_from_flatbuffer`). Crash files to reproduce: [crash-1feb368861083e3d242e5c3fcb1090869f4819c4.txt](https://github.com/pytorch/pytorch/files/10795267/crash-1feb368861083e3d242e5c3fcb1090869f4819c4.txt) [crash-7e8ffd314223be96b43ca246d3d3481702869455.txt](https://github.com/pytorch/pytorch/files/10795268/crash-7e8ffd314223be96b43ca246d3d3481702869455.txt) [crash-ad4d7c6183af8f34fe1cb5c8133315c6389c409f.txt](https://github.com/pytorch/pytorch/files/10795279/crash-ad4d7c6183af8f34fe1cb5c8133315c6389c409f.txt) Pull Request resolved: https://github.com/pytorch/pytorch/pull/95221 Approved by: https://github.com/qihqi, https://github.com/davidberard98
393 lines
12 KiB
C++
393 lines
12 KiB
C++
#include <test/cpp/jit/test_utils.h>
|
|
|
|
#include <gtest/gtest.h>
|
|
|
|
#include <c10/core/TensorOptions.h>
|
|
#include <torch/csrc/autograd/generated/variable_factories.h>
|
|
#include <torch/csrc/jit/api/module.h>
|
|
#include <torch/csrc/jit/mobile/import.h>
|
|
#include <torch/csrc/jit/mobile/import_data.h>
|
|
#include <torch/csrc/jit/mobile/module.h>
|
|
#include <torch/csrc/jit/mobile/train/export_data.h>
|
|
#include <torch/csrc/jit/mobile/train/optim/sgd.h>
|
|
#include <torch/csrc/jit/mobile/train/random.h>
|
|
#include <torch/csrc/jit/mobile/train/sequential.h>
|
|
#include <torch/csrc/jit/serialization/import.h>
|
|
#include <torch/data/dataloader.h>
|
|
#include <torch/torch.h>
|
|
|
|
// Tests go in torch::jit
|
|
namespace torch {
|
|
namespace jit {
|
|
|
|
// Trains the same one-parameter module twice with ::torch::optim::SGD:
// - once as a full-JIT module reloaded via save()/load();
// - once as a lite-interpreter module reloaded via
//   _save_for_mobile()/_load_for_mobile().
// Both runs must end with the same value for the trained parameter.
TEST(LiteTrainerTest, Params) {
  Module m("m");
  m.register_parameter("foo", torch::ones({1}, at::requires_grad()), false);
  m.define(R"(
    def forward(self, x):
      b = 1.0
      return self.foo * x + b
  )");
  const double learning_rate = 0.1, momentum = 0.1;
  const int n_epoc = 10;
  // init: y = x + 1;
  // target: y = 2 x + 1
  std::vector<std::pair<Tensor, Tensor>> trainData{
      {1 * torch::ones({1}), 3 * torch::ones({1})},
  };

  // Runs the shared training schedule on `module` with `optimizer`. Generic
  // so the identical loop drives both the full-JIT and the mobile module.
  auto run_training = [&](auto& module, auto& optimizer) {
    for (int epoc = 0; epoc < n_epoc; ++epoc) {
      for (const auto& data : trainData) {
        const auto& source = data.first;
        const auto& targets = data.second;
        optimizer.zero_grad();
        std::vector<IValue> train_inputs{source};
        auto output = module.forward(train_inputs).toTensor();
        auto loss = ::torch::l1_loss(output, targets);
        loss.backward();
        optimizer.step();
      }
    }
  };

  // Reference: Full jit
  std::stringstream ms;
  m.save(ms);
  auto mm = load(ms);
  // mm.train();
  std::vector<::at::Tensor> parameters;
  for (const auto& parameter : mm.parameters()) {
    parameters.emplace_back(parameter);
  }
  ::torch::optim::SGD optimizer(
      parameters, ::torch::optim::SGDOptions(learning_rate).momentum(momentum));
  run_training(mm, optimizer);

  // Test: lite interpreter trained with the same full-JIT optimizer
  std::stringstream ss;
  m._save_for_mobile(ss);
  mobile::Module bc = _load_for_mobile(ss);
  std::vector<::at::Tensor> bc_parameters = bc.parameters();
  ::torch::optim::SGD bc_optimizer(
      bc_parameters,
      ::torch::optim::SGDOptions(learning_rate).momentum(momentum));
  run_training(bc, bc_optimizer);

  // Guard the [0] accesses: an empty or mismatched parameter list (e.g. from
  // a loading regression) must fail the test, not index out of bounds.
  ASSERT_FALSE(parameters.empty());
  ASSERT_EQ(parameters.size(), bc_parameters.size());
  EXPECT_EQ(parameters[0].item<float>(), bc_parameters[0].item<float>());
}
|
|
|
|
// TODO: Re-enable these tests once parameters are correctly loaded on mobile.
|
|
/*
|
|
TEST(MobileTest, NamedParameters) {
|
|
Module m("m");
|
|
m.register_parameter("foo", torch::ones({}), false);
|
|
m.define(R"(
|
|
def add_it(self, x):
|
|
b = 4
|
|
return self.foo + x + b
|
|
)");
|
|
Module child("m2");
|
|
child.register_parameter("foo", 4 * torch::ones({}), false);
|
|
child.register_parameter("bar", 4 * torch::ones({}), false);
|
|
m.register_module("child1", child);
|
|
m.register_module("child2", child.clone());
|
|
std::stringstream ss;
|
|
m._save_for_mobile(ss);
|
|
mobile::Module bc = _load_for_mobile(ss);
|
|
|
|
auto full_params = m.named_parameters();
|
|
auto mobile_params = bc.named_parameters();
|
|
AT_ASSERT(full_params.size() == mobile_params.size());
|
|
for (const auto& e : full_params) {
|
|
AT_ASSERT(e.value.item().toInt() ==
|
|
mobile_params[e.name].item().toInt());
|
|
}
|
|
}
|
|
|
|
TEST(MobileTest, SaveLoadParameters) {
|
|
Module m("m");
|
|
m.register_parameter("foo", torch::ones({}), false);
|
|
m.define(R"(
|
|
def add_it(self, x):
|
|
b = 4
|
|
return self.foo + x + b
|
|
)");
|
|
Module child("m2");
|
|
child.register_parameter("foo", 4 * torch::ones({}), false);
|
|
child.register_parameter("bar", 3 * torch::ones({}), false);
|
|
m.register_module("child1", child);
|
|
m.register_module("child2", child.clone());
|
|
auto full_params = m.named_parameters();
|
|
std::stringstream ss;
|
|
std::stringstream ss_data;
|
|
m._save_for_mobile(ss);
|
|
|
|
// load mobile module, save mobile named parameters
|
|
mobile::Module bc = _load_for_mobile(ss);
|
|
_save_parameters(bc.named_parameters(), ss_data);
|
|
|
|
// load back the named parameters, compare to full-jit Module's
|
|
auto mobile_params = _load_parameters(ss_data);
|
|
AT_ASSERT(full_params.size() == mobile_params.size());
|
|
for (const auto& e : full_params) {
|
|
AT_ASSERT(e.value.item<int>() == mobile_params[e.name].item<int>());
|
|
}
|
|
}
|
|
*/
|
|
|
|
// Round-trips the named parameters of a mobile module that registers none
// (neither the root nor its two children call register_parameter) and
// checks that loading them back yields an empty collection.
TEST(MobileTest, SaveLoadParametersEmpty) {
  Module m("m");
  m.define(R"(
    def add_it(self, x):
      b = 4
      return x + b
  )");
  Module child("m2");
  m.register_module("child1", child);
  m.register_module("child2", child.clone());
  std::stringstream ss;
  std::stringstream ss_data;
  m._save_for_mobile(ss);

  // load mobile module, save mobile named parameters
  mobile::Module bc = _load_for_mobile(ss);
  _save_parameters(bc.named_parameters(), ss_data);

  // load back the named parameters, test is empty
  auto mobile_params = _load_parameters(ss_data);
  EXPECT_TRUE(mobile_params.empty())
      << "expected no parameters, got " << mobile_params.size();
}
|
|
|
|
// _save_parameters() without the flatbuffer flag must emit a ZIP container.
TEST(MobileTest, SaveParametersDefaultsToZip) {
  // Save some empty parameters.
  std::map<std::string, at::Tensor> empty_parameters;
  std::stringstream ss_data;
  _save_parameters(empty_parameters, ss_data);

  // Take a single copy of the serialized bytes (str() copies on every call).
  const std::string data = ss_data.str();

  // Verify that parameters were serialized to a ZIP container, identified by
  // the leading bytes "PK\x03\x04". ASSERT (not EXPECT) on the length so a
  // short buffer stops the test before it is indexed out of range.
  ASSERT_GE(data.size(), 4u);
  EXPECT_EQ(data[0], 'P');
  EXPECT_EQ(data[1], 'K');
  EXPECT_EQ(data[2], '\x03');
  EXPECT_EQ(data[3], '\x04');
}
|
|
|
|
// _save_parameters(..., /*use_flatbuffer=*/true) must emit a flatbuffer.
TEST(MobileTest, SaveParametersCanUseFlatbuffer) {
  // Save some empty parameters using flatbuffer.
  std::map<std::string, at::Tensor> empty_parameters;
  std::stringstream ss_data;
  _save_parameters(empty_parameters, ss_data, /*use_flatbuffer=*/true);

  // Take a single copy of the serialized bytes (str() copies on every call).
  const std::string data = ss_data.str();

  // Verify that parameters were serialized to a flatbuffer. The flatbuffer
  // magic bytes should be at offsets 4..7. The first four bytes contain an
  // offset to the actual flatbuffer data. ASSERT (not EXPECT) on the length
  // so a short buffer stops the test before it is indexed out of range.
  ASSERT_GE(data.size(), 8u);
  EXPECT_EQ(data[4], 'P');
  EXPECT_EQ(data[5], 'T');
  EXPECT_EQ(data[6], 'M');
  EXPECT_EQ(data[7], 'F');
}
|
|
|
|
// Saves two scalar tensors as flatbuffer-encoded parameters and checks that
// loading them back restores the same names and values.
TEST(MobileTest, SaveLoadParametersUsingFlatbuffers) {
  // Create some simple parameters to save.
  std::map<std::string, at::Tensor> input_params;
  input_params["four_by_ones"] = 4 * torch::ones({});
  input_params["three_by_ones"] = 3 * torch::ones({});

  // Serialize them using flatbuffers.
  std::stringstream data;
  _save_parameters(input_params, data, /*use_flatbuffer=*/true);

  // The flatbuffer magic bytes should be at offsets 4..7. Copy the bytes
  // once, and ASSERT on the length so the indexing below stays in range.
  const std::string serialized = data.str();
  ASSERT_GE(serialized.size(), 8u);
  EXPECT_EQ(serialized[4], 'P');
  EXPECT_EQ(serialized[5], 'T');
  EXPECT_EQ(serialized[6], 'M');
  EXPECT_EQ(serialized[7], 'F');

  // Read them back and check that they survived the trip.
  auto output_params = _load_parameters(data);
  EXPECT_EQ(output_params.size(), 2u);
  // Look entries up with count()/at() instead of operator[], which would
  // default-insert a missing key and then fail obscurely inside item().
  for (const auto& entry : input_params) {
    const std::string& name = entry.first;
    ASSERT_EQ(output_params.count(name), 1u) << "missing parameter " << name;
    EXPECT_EQ(
        output_params.at(name).item<float>(), entry.second.item<float>());
  }
}
|
|
|
|
// Data that matches neither supported container must be rejected.
TEST(MobileTest, LoadParametersUnexpectedFormatShouldThrow) {
  // getFileFormat() needs at least 8 bytes to detect the container type, so
  // supply 12 bytes of plain ASCII that resemble neither a ZIP nor a
  // Flatbuffer file.
  const std::string not_a_known_format = "abcdefghijkl";
  std::stringstream bad_data;
  bad_data << not_a_known_format;

  // Any exception is acceptable; the loader just must not return normally.
  EXPECT_ANY_THROW(_load_parameters(bad_data));
}
|
|
|
|
// A stream with no bytes at all must be rejected by _load_parameters().
TEST(MobileTest, LoadParametersEmptyDataShouldThrow) {
  std::stringstream no_bytes;
  EXPECT_ANY_THROW(_load_parameters(no_bytes));
}
|
|
|
|
// A buffer that carries the flatbuffer identifier but whose contents are
// garbage must be rejected by flatbuffer verification rather than parsed.
TEST(MobileTest, LoadParametersMalformedFlatbuffer) {
  // Manually create some data with Flatbuffer header.
  //
  // These literals contain embedded NUL bytes. Streaming a `const char*` with
  // operator<< stops at the first '\0', which would silently drop those bytes
  // (and everything after them in each literal). Write every byte explicitly;
  // sizeof() - 1 excludes only the implicit terminator.
  static const char kHeader[] = "PK\x03\x04PTMF\x00\x00";
  static const char kBody[] = "*}NV\xb3\xfa\xdf\x00pa";
  std::stringstream bad_data;
  bad_data.write(kHeader, sizeof(kHeader) - 1);
  bad_data.write(kBody, sizeof(kBody) - 1);

  // Loading parameters from it should throw an exception.
  ASSERT_THROWS_WITH_MESSAGE(
      _load_parameters(bad_data), "Malformed Flatbuffer module");
}
|
|
|
|
// Trains the same one-parameter module twice with identical hyperparameters
// and checks both runs end with the same parameter value:
// - full JIT with ::torch::optim::SGD;
// - the lite interpreter with ::torch::jit::mobile::SGD.
TEST(LiteTrainerTest, SGD) {
  Module m("m");
  m.register_parameter("foo", torch::ones({1}, at::requires_grad()), false);
  m.define(R"(
    def forward(self, x):
      b = 1.0
      return self.foo * x + b
  )");
  const double learning_rate = 0.1, momentum = 0.1;
  const int n_epoc = 10;
  // init: y = x + 1;
  // target: y = 2 x + 1
  std::vector<std::pair<Tensor, Tensor>> trainData{
      {1 * torch::ones({1}), 3 * torch::ones({1})},
  };

  // Runs the shared training schedule on `module` with `optimizer`. Generic
  // so the identical loop drives both module and optimizer flavors.
  auto run_training = [&](auto& module, auto& optimizer) {
    for (int epoc = 0; epoc < n_epoc; ++epoc) {
      for (const auto& data : trainData) {
        const auto& source = data.first;
        const auto& targets = data.second;
        optimizer.zero_grad();
        std::vector<IValue> train_inputs{source};
        auto output = module.forward(train_inputs).toTensor();
        auto loss = ::torch::l1_loss(output, targets);
        loss.backward();
        optimizer.step();
      }
    }
  };

  // Reference: Full jit and torch::optim::SGD
  std::stringstream ms;
  m.save(ms);
  auto mm = load(ms);
  std::vector<::at::Tensor> parameters;
  for (const auto& parameter : mm.parameters()) {
    parameters.emplace_back(parameter);
  }
  ::torch::optim::SGD optimizer(
      parameters, ::torch::optim::SGDOptions(learning_rate).momentum(momentum));
  run_training(mm, optimizer);

  // Test: lite interpreter and torch::jit::mobile::SGD
  std::stringstream ss;
  m._save_for_mobile(ss);
  mobile::Module bc = _load_for_mobile(ss);
  std::vector<::at::Tensor> bc_parameters = bc.parameters();
  ::torch::jit::mobile::SGD bc_optimizer(
      bc_parameters,
      ::torch::jit::mobile::SGDOptions(learning_rate).momentum(momentum));
  run_training(bc, bc_optimizer);

  // Guard the [0] accesses: an empty or mismatched parameter list (e.g. from
  // a loading regression) must fail the test, not index out of bounds.
  ASSERT_FALSE(parameters.empty());
  ASSERT_EQ(parameters.size(), bc_parameters.size());
  EXPECT_EQ(parameters[0].item<float>(), bc_parameters[0].item<float>());
}
|
|
|
|
namespace {
|
|
struct DummyDataset : torch::data::datasets::Dataset<DummyDataset, int> {
|
|
explicit DummyDataset(size_t size = 100) : size_(size) {}
|
|
|
|
int get(size_t index) override {
|
|
// NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
|
|
return 1 + index;
|
|
}
|
|
torch::optional<size_t> size() const override {
|
|
return size_;
|
|
}
|
|
|
|
size_t size_;
|
|
};
|
|
} // namespace
|
|
|
|
// mobile::SequentialSampler must work with torch's data loader and visit the
// dataset in order, covering every example exactly once.
TEST(LiteTrainerTest, SequentialSampler) {
  // test that sampler can be used with dataloader
  const int kDatasetSize = 25;
  const int kBatchSize = 10;
  auto data_loader = torch::data::make_data_loader<mobile::SequentialSampler>(
      DummyDataset(kDatasetSize), kBatchSize);

  // DummyDataset yields `index + 1`, so a sequential walk must produce
  // 1, 2, ..., kDatasetSize in that order.
  int expected = 1;
  for (const auto& batch : *data_loader) {
    for (const auto& example : batch) {
      EXPECT_EQ(example, expected);
      ++expected;
    }
  }
  // Without this check the test would pass vacuously if the loader yielded
  // nothing (or stopped early).
  EXPECT_EQ(expected, kDatasetSize + 1);
}
|
|
|
|
// Every index handed out by mobile::RandomSampler must lie in [0, size), and
// the sampler must report exhaustion once all indices have been drawn.
TEST(LiteTrainerTest, RandomSamplerReturnsIndicesInCorrectRange) {
  const size_t kNumIndices = 10;
  mobile::RandomSampler sampler(kNumIndices);

  // Consume the whole epoch in three uneven batches (3 + 5 + 2 == 10).
  for (const size_t batch_size : {size_t{3}, size_t{5}, size_t{2}}) {
    // Copy the batch into a local before iterating so the loop does not
    // depend on the lifetime of the temporary optional returned by next().
    const std::vector<size_t> batch = sampler.next(batch_size).value();
    for (const size_t index : batch) {
      AT_ASSERT(index < kNumIndices);
    }
  }

  AT_ASSERT(!sampler.next(kNumIndices).has_value());
}
|
|
|
|
// Requesting more indices than remain returns only the remainder.
TEST(LiteTrainerTest, RandomSamplerReturnsLessValuesForLastBatch) {
  mobile::RandomSampler sampler(5);

  const auto first_batch = sampler.next(3).value();
  AT_ASSERT(first_batch.size() == 3);

  // Only 5 - 3 == 2 indices are left, however many are asked for.
  const auto last_batch = sampler.next(100).value();
  AT_ASSERT(last_batch.size() == 2);

  AT_ASSERT(!sampler.next(2).has_value());
}
|
|
|
|
// reset() without arguments makes a fully drained sampler usable again with
// its original size.
TEST(LiteTrainerTest, RandomSamplerResetsWell) {
  mobile::RandomSampler sampler(5);

  // Draws all 5 indices in one batch, then checks nothing is left.
  auto drain_full_epoch = [&sampler]() {
    AT_ASSERT(sampler.next(5).value().size() == 5);
    AT_ASSERT(!sampler.next(2).has_value());
  };

  drain_full_epoch();
  sampler.reset();
  drain_full_epoch();
}
|
|
|
|
// reset(new_size) resizes the sampler for the next epoch, whether the new
// size is larger or smaller than the previous one.
TEST(LiteTrainerTest, RandomSamplerResetsWithNewSizeWell) {
  mobile::RandomSampler sampler(5);

  // Draws exactly `epoch_size` indices in one batch, then checks the sampler
  // has nothing left.
  auto drain = [&sampler](size_t epoch_size) {
    AT_ASSERT(sampler.next(epoch_size).value().size() == epoch_size);
    AT_ASSERT(!sampler.next(2).has_value());
  };

  drain(5);
  // Grow to 7, then shrink to 3.
  for (const size_t new_size : {size_t{7}, size_t{3}}) {
    sampler.reset(new_size);
    drain(new_size);
  }
}
|
|
|
|
} // namespace jit
|
|
} // namespace torch
|