pytorch/test/cpp/jit/test_lite_interpreter.cpp
Martin Yuan f362cd510d Move prim ops from JIT registration to C10 (#30612)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/30612

This is the first step in moving prim ops to c10 registration. Once reviewers approve these initial changes, more operators will be moved in the same style.

Test Plan: Imported from OSS

Differential Revision: D19237648

Pulled By: iseeyuan

fbshipit-source-id: c5a519604efffb80564a556536f17d829f71d9f9
2020-01-04 13:47:44 -08:00

158 lines
4.1 KiB
C++

#include <test/cpp/jit/test_base.h>
#include <torch/csrc/jit/script/module.h>
#include <torch/csrc/autograd/generated/variable_factories.h>
#include <torch/csrc/jit/mobile/import.h>
#include <torch/csrc/jit/mobile/module.h>
#include <torch/csrc/jit/import.h>
// Tests go in torch::jit
namespace torch {
namespace jit {
void testLiteInterpreterAdd() {
script::Module m("m");
m.register_parameter("foo", torch::ones({}), false);
// TODO: support default param val, which was pushed in
// function schema's checkAndNormalizeInputs()
// m.define(R"(
// def add_it(self, x, b : int = 4):
// return self.foo + x + b
// )");
m.define(R"(
def add_it(self, x):
b = 4
return self.foo + x + b
)");
std::vector<IValue> inputs;
auto minput = 5 * torch::ones({});
inputs.emplace_back(minput);
auto ref = m.run_method("add_it", minput);
std::stringstream ss;
m._save_for_mobile(ss);
mobile::Module bc = _load_for_mobile(ss);
IValue res;
for (int i = 0; i < 3; ++i) {
auto bcinputs = inputs;
res = bc.run_method("add_it", bcinputs);
}
auto resd = res.toTensor().item<float>();
auto refd = ref.toTensor().item<float>();
AT_ASSERT(resd == refd);
}
void testLiteInterpreterConv() {
auto s = std::getenv("PYTORCH_TEST_WITH_TSAN");
if (s && strcmp(s, "1") == 0)
return;
std::vector<torch::jit::IValue> inputs;
script::Module m("m");
m.register_parameter("weight", torch::ones({20, 1, 5, 5}), false);
m.register_parameter("bias", torch::ones({20}), false);
m.define(R"(
def forward(self, input):
return torch._convolution(input, self.weight, self.bias, [1, 1], [0, 0], [1, 1], False, [0, 0], 1, False, False, True)
)");
inputs.push_back(torch::ones({1, 1, 28, 28}));
auto outputref = m.forward(inputs).toTensor();
std::stringstream ss;
m._save_for_mobile(ss);
mobile::Module bc = _load_for_mobile(ss);
IValue res;
for (int i = 0; i < 3; ++i) {
res = bc.run_method("forward", inputs);
}
auto output = res.toTensor();
AT_ASSERT(outputref.dim() == output.dim());
AT_ASSERT(outputref[0][0][0][0].item<int>() == output[0][0][0][0].item<int>());
}
void testLiteInterpreterInline() {
script::Module m("m");
m.define(R"JIT(
def foo1(self, x):
return x + 1
def foo2(self, x):
return self.foo1(x) + 2
def foo3(self, x):
return self.foo2(x) + 3
)JIT");
std::stringstream ss;
m._save_for_mobile(ss);
mobile::Module bc = _load_for_mobile(ss);
std::vector<torch::jit::IValue> inputs({torch::ones({})});
auto output = bc.run_method("foo3", inputs);
AT_ASSERT(output.toTensor().item<float>() == 7.0);
}
void testLiteInterpreterTuple() {
script::Module m("m");
m.define(R"JIT(
def foo(self, x):
return (1, 2, x + 3)
def forward(self, x):
tuple = self.foo(x)
return tuple
)JIT");
std::stringstream ss;
m._save_for_mobile(ss);
mobile::Module bc = _load_for_mobile(ss);
std::vector<torch::jit::IValue> inputs({torch::ones({})});
auto output = bc.run_method("forward", inputs);
AT_ASSERT(output.toTuple()->elements()[1].toInt() == 2);
}
void testLiteInterpreterPrimOverload() {
script::Module m("m");
m.define(R"JIT(
def forward(self, x):
result = [1, 2]
result.append(3)
return result
)JIT");
std::stringstream ss;
m._save_for_mobile(ss);
mobile::Module bc = _load_for_mobile(ss);
std::vector<torch::jit::IValue> inputs({torch::ones({})});
auto output = bc.run_method("forward", inputs);
AT_ASSERT(output.toIntList()[2] == 3);
}
void testLiteInterpreterPrim() {
script::Module m("m");
m.define(R"JIT(
def forward(self, x):
return int(x)
)JIT");
std::vector<IValue> inputs;
auto minput = 3.5 * torch::ones({});
inputs.emplace_back(minput);
auto ref = m.run_method("forward", minput);
std::stringstream ss;
m._save_for_mobile(ss);
mobile::Module bc = _load_for_mobile(ss);
IValue res;
for (int i = 0; i < 3; ++i) {
auto bcinputs = inputs;
res = bc.run_method("forward", bcinputs);
}
auto resi = res.toInt();
auto refi = ref.toInt();
AT_ASSERT(resi == refi);
}
} // namespace jit
} // namespace torch