Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/33294

1. Serialize the bytecode of __setstate__ and run it when loading the model.
2. One use case is quantization. To test this use case, a few operators are
   registered temporarily for the lite interpreter. The "_" prefix
   registration will be removed once the operators have all been migrated to
   mobile.

Test Plan: Imported from OSS

Differential Revision: D20162898

Pulled By: iseeyuan

fbshipit-source-id: 7a3180807bf38fbce594d86993896861f12bb58c
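In miniature, the round trip this change enables, sketched with the same APIs
that testLiteInterpreterSetState at the bottom of this file exercises:

  script::Module m("m");
  m.register_parameter("foo", torch::ones({}), false);
  m.define(R"(
    def __getstate__(self):
      return self.foo + self.foo
    def __setstate__(self, a):
      self.foo = a
  )");
  std::stringstream ss;
  m._save_for_mobile(ss);                   // __setstate__'s bytecode is serialized
  mobile::Module bc = _load_for_mobile(ss); // __setstate__ runs here, restoring foo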
#include <test/cpp/jit/test_base.h>
#include <torch/csrc/jit/api/module.h>
#include <torch/csrc/autograd/generated/variable_factories.h>
#include <torch/csrc/jit/mobile/import.h>
#include <torch/csrc/jit/mobile/module.h>
#include <torch/csrc/jit/serialization/import.h>
#include <torch/torch.h>

#include <c10/core/TensorOptions.h>
// Tests go in torch::jit
namespace torch {
namespace jit {

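// Round-trips a module through _save_for_mobile / _load_for_mobile and checks
// that upsample_nearest2d produces the same tensor as the full JIT module.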
void testLiteInterpreterUpsampleNearest2d() {
  script::Module m("m");
  m.define(R"(
    def forward(self, input: Tensor, scale: float):
      return torch.upsample_nearest2d(input, [1, 1], float(scale), float(scale))
  )");

  std::vector<IValue> inputs;
  inputs.emplace_back(torch::rand({1, 3, 128, 128}));
  inputs.emplace_back(at::Scalar(2.0));
  auto ref = m.forward(inputs);

  std::stringstream ss;
  m._save_for_mobile(ss);
  mobile::Module bc = _load_for_mobile(ss);
  IValue res;
  res = bc.forward(inputs);

  auto resd = res.toTensor();
  auto refd = ref.toTensor();
  ASSERT_TRUE(resd.equal(refd));
}

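// A module with a registered parameter and a custom method; verifies that
// run_method on the mobile module matches the full JIT result.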
void testLiteInterpreterAdd() {
  script::Module m("m");
  m.register_parameter("foo", torch::ones({}), false);
  // TODO: support default param values, which are pushed in
  // function schema's checkAndNormalizeInputs()
  // m.define(R"(
  //   def add_it(self, x, b : int = 4):
  //     return self.foo + x + b
  // )");
  m.define(R"(
    def add_it(self, x):
      b = 4
      return self.foo + x + b
  )");

  std::vector<IValue> inputs;
  auto minput = 5 * torch::ones({});
  inputs.emplace_back(minput);
  auto ref = m.run_method("add_it", minput);

  std::stringstream ss;
  m._save_for_mobile(ss);
  mobile::Module bc = _load_for_mobile(ss);
  IValue res;
  for (int i = 0; i < 3; ++i) {
    auto bcinputs = inputs;
    res = bc.run_method("add_it", bcinputs);
  }

  auto resd = res.toTensor().item<float>();
  auto refd = ref.toTensor().item<float>();
  AT_ASSERT(resd == refd);
}

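// Runs a convolution via torch._convolution through the lite interpreter and
// compares the output rank and a sample element against the full JIT output.
// Skipped under TSAN.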
void testLiteInterpreterConv() {
  auto s = std::getenv("PYTORCH_TEST_WITH_TSAN");
  if (s && strcmp(s, "1") == 0)
    return;

  std::vector<torch::jit::IValue> inputs;

  script::Module m("m");
  m.register_parameter("weight", torch::ones({20, 1, 5, 5}), false);
  m.register_parameter("bias", torch::ones({20}), false);
  m.define(R"(
    def forward(self, input):
      return torch._convolution(input, self.weight, self.bias, [1, 1], [0, 0], [1, 1], False, [0, 0], 1, False, False, True)
  )");

  inputs.push_back(torch::ones({1, 1, 28, 28}));

  auto outputref = m.forward(inputs).toTensor();

  std::stringstream ss;
  m._save_for_mobile(ss);
  mobile::Module bc = _load_for_mobile(ss);
  IValue res;
  for (int i = 0; i < 3; ++i) {
    res = bc.run_method("forward", inputs);
  }
  auto output = res.toTensor();
  AT_ASSERT(outputref.dim() == output.dim());
  AT_ASSERT(outputref[0][0][0][0].item<int>() == output[0][0][0][0].item<int>());
}

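// Method-to-method calls (foo3 -> foo2 -> foo1) must still resolve after
// mobile serialization.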
void testLiteInterpreterInline() {
  script::Module m("m");
  m.define(R"JIT(
  def foo1(self, x):
      return x + 1

  def foo2(self, x):
      return self.foo1(x) + 2

  def foo3(self, x):
      return self.foo2(x) + 3
  )JIT");
  std::stringstream ss;
  m._save_for_mobile(ss);
  mobile::Module bc = _load_for_mobile(ss);
  std::vector<torch::jit::IValue> inputs({torch::ones({})});
  auto output = bc.run_method("foo3", inputs);
  AT_ASSERT(output.toTensor().item<float>() == 7.0);
}

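// A tuple returned through a helper method must survive mobile serialization.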
void testLiteInterpreterTuple() {
  script::Module m("m");
  m.define(R"JIT(
  def foo(self, x):
      return (1, 2, x + 3)

  def forward(self, x):
      tuple = self.foo(x)
      return tuple
  )JIT");
  std::stringstream ss;
  m._save_for_mobile(ss);
  mobile::Module bc = _load_for_mobile(ss);
  std::vector<torch::jit::IValue> inputs({torch::ones({})});
  auto output = bc.run_method("forward", inputs);
  AT_ASSERT(output.toTuple()->elements()[1].toInt() == 2);
}

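// Exercises an overloaded op (append on a List[int]) through the lite
// interpreter.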
void testLiteInterpreterPrimOverload() {
  script::Module m("m");
  m.define(R"JIT(
  def forward(self, x):
      result = [1, 2]
      result.append(3)
      return result
  )JIT");
  std::stringstream ss;
  m._save_for_mobile(ss);
  mobile::Module bc = _load_for_mobile(ss);
  std::vector<torch::jit::IValue> inputs({torch::ones({})});
  auto output = bc.run_method("forward", inputs);
  AT_ASSERT(output.toIntList()[2] == 3);
}

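// Exercises a prim conversion, int(x), and compares the mobile result with
// the full JIT result.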
void testLiteInterpreterPrim() {
  script::Module m("m");
  m.define(R"JIT(
  def forward(self, x):
      return int(x)
  )JIT");

  std::vector<IValue> inputs;
  auto minput = 3.5 * torch::ones({});
  inputs.emplace_back(minput);
  auto ref = m.run_method("forward", minput);

  std::stringstream ss;
  m._save_for_mobile(ss);
  mobile::Module bc = _load_for_mobile(ss);
  IValue res;
  for (int i = 0; i < 3; ++i) {
    auto bcinputs = inputs;
    res = bc.run_method("forward", bcinputs);
  }

  auto resi = res.toInt();
  auto refi = ref.toInt();
  AT_ASSERT(resi == refi);
}

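// An archive produced by the regular JIT save() carries no mobile bytecode,
// so _load_for_mobile should fail with "file not found".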
void testLiteInterpreterLoadOrigJit() {
  script::Module m("m");
  m.register_parameter("foo", torch::ones({}), false);
  m.define(R"(
    def forward(self, x):
      b = 4
      return self.foo + x + b
  )");
  std::stringstream ss;
  m.save(ss);
  ASSERT_THROWS_WITH(_load_for_mobile(ss), "file not found");
}

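// Calling a method that the module never defined should fail with
// "is not defined".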
void testLiteInterpreterWrongMethodName() {
  script::Module m("m");
  m.register_parameter("foo", torch::ones({}), false);
  m.define(R"(
    def add(self, x):
      b = 4
      return self.foo + x + b
  )");
  std::stringstream ss;
  m._save_for_mobile(ss);
  mobile::Module bc = _load_for_mobile(ss);
  std::vector<IValue> inputs;
  auto minput = 5 * torch::ones({});
  inputs.emplace_back(minput);
  ASSERT_THROWS_WITH(bc.run_method("forward", inputs), "is not defined");
}

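// Trains a single parameter with SGD through both the full JIT module and the
// mobile module and checks that the resulting parameter values agree.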
void testLiteInterpreterParams() {
  script::Module m("m");
  m.register_parameter("foo", torch::ones({1}, at::requires_grad()), false);
  m.define(R"(
    def forward(self, x):
      b = 1.0
      return self.foo * x + b
  )");
  double learning_rate = 0.1, momentum = 0.1;
  int n_epoc = 10;
  // init: y = x + 1;
  // target: y = 2 x + 1
  std::vector<std::pair<Tensor, Tensor>> trainData{
      {1 * torch::ones({1}), 3 * torch::ones({1})},
  };
  // Reference: full JIT
  std::stringstream ms;
  m.save(ms);
  auto mm = load(ms);
  // mm.train();
  std::vector<::at::Tensor> parameters;
  for (auto parameter : mm.parameters()) {
    parameters.emplace_back(parameter);
  }
  ::torch::optim::SGD optimizer(
      parameters,
      ::torch::optim::SGDOptions(learning_rate).momentum(momentum));
  for (int epoc = 0; epoc < n_epoc; ++epoc) {
    for (auto& data : trainData) {
      auto source = data.first, targets = data.second;
      optimizer.zero_grad();
      std::vector<IValue> train_inputs{source};
      auto output = mm.forward(train_inputs).toTensor();
      auto loss = ::torch::l1_loss(output, targets);
      loss.backward();
      optimizer.step();
    }
  }
  std::stringstream ss;
  m._save_for_mobile(ss);
  mobile::Module bc = _load_for_mobile(ss);
  std::vector<::at::Tensor> bc_parameters = bc.parameters();
  ::torch::optim::SGD bc_optimizer(
      bc_parameters,
      ::torch::optim::SGDOptions(learning_rate).momentum(momentum));
  for (int epoc = 0; epoc < n_epoc; ++epoc) {
    for (auto& data : trainData) {
      auto source = data.first, targets = data.second;
      bc_optimizer.zero_grad();
      std::vector<IValue> train_inputs{source};
      auto output = bc.forward(train_inputs).toTensor();
      auto loss = ::torch::l1_loss(output, targets);
      loss.backward();
      bc_optimizer.step();
    }
  }
  AT_ASSERT(parameters[0].item<float>() == bc_parameters[0].item<float>());
}

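// The core test for this change: __getstate__ / __setstate__ are serialized
// for mobile, and __setstate__ runs on load, so the reloaded mobile module's
// forward matches the full JIT reference.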
void testLiteInterpreterSetState() {
  script::Module m("m");
  m.register_parameter("foo", torch::ones({}), false);
  m.define(R"(
    def __getstate__(self):
      return self.foo + self.foo
    def __setstate__(self, a):
      self.foo = a
    def forward(self, x):
      b = 4
      return self.foo + x + b
  )");

  std::vector<IValue> inputs;
  auto minput = 5 * torch::ones({});
  inputs.emplace_back(minput);

  std::stringstream ms;
  m.save(ms);
  auto loaded_m = load(ms);
  auto ref = loaded_m.run_method("forward", minput);

  std::stringstream ss;
  m._save_for_mobile(ss);
  mobile::Module bc = _load_for_mobile(ss);
  IValue res;
  for (int i = 0; i < 3; ++i) {
    auto bcinputs = inputs;
    res = bc.run_method("forward", bcinputs);
  }

  auto resd = res.toTensor().item<float>();
  auto refd = ref.toTensor().item<float>();
  AT_ASSERT(resd == refd);
}

} // namespace jit
} // namespace torch