#include #include #include #include #include #include #include #include #include // Tests go in torch::jit namespace torch { namespace jit { void testLiteInterpreterUpsampleNearest2d() { Module m("m"); m.define(R"( def forward(self, input: Tensor, scale:float): return torch.upsample_nearest2d(input, [1, 1], float(scale), float(scale)) )"); std::vector inputs; inputs.emplace_back(torch::rand({1, 3, 128, 128})); inputs.emplace_back(at::Scalar(2.0)); auto ref = m.forward(inputs); std::stringstream ss; m._save_for_mobile(ss); mobile::Module bc = _load_for_mobile(ss); IValue res; res = bc.forward(inputs); auto resd = res.toTensor(); auto refd = ref.toTensor(); ASSERT_TRUE(resd.equal(refd)); } void testLiteInterpreterAdd() { Module m("m"); m.register_parameter("foo", torch::ones({}), false); // TODO: support default param val, which was pushed in // function schema's checkAndNormalizeInputs() // m.define(R"( // def add_it(self, x, b : int = 4): // return self.foo + x + b // )"); m.define(R"( def add_it(self, x): b = 4 return self.foo + x + b )"); std::vector inputs; auto minput = 5 * torch::ones({}); inputs.emplace_back(minput); auto ref = m.run_method("add_it", minput); std::stringstream ss; m._save_for_mobile(ss); mobile::Module bc = _load_for_mobile(ss); IValue res; for (int i = 0; i < 3; ++i) { auto bcinputs = inputs; res = bc.run_method("add_it", bcinputs); } auto resd = res.toTensor().item(); auto refd = ref.toTensor().item(); AT_ASSERT(resd == refd); } void testLiteInterpreterConv() { auto s = std::getenv("PYTORCH_TEST_WITH_TSAN"); if (s && strcmp(s, "1") == 0) return; std::vector inputs; Module m("m"); m.register_parameter("weight", torch::ones({20, 1, 5, 5}), false); m.register_parameter("bias", torch::ones({20}), false); m.define(R"( def forward(self, input): return torch._convolution(input, self.weight, self.bias, [1, 1], [0, 0], [1, 1], False, [0, 0], 1, False, False, True) )"); inputs.push_back(torch::ones({1, 1, 28, 28})); auto outputref = m.forward(inputs).toTensor(); std::stringstream ss; m._save_for_mobile(ss); mobile::Module bc = _load_for_mobile(ss); IValue res; for (int i = 0; i < 3; ++i) { res = bc.run_method("forward", inputs); } auto output = res.toTensor(); AT_ASSERT(outputref.dim() == output.dim()); AT_ASSERT(outputref[0][0][0][0].item() == output[0][0][0][0].item()); } void testLiteInterpreterInline() { Module m("m"); m.define(R"JIT( def foo1(self, x): return x + 1 def foo2(self, x): return self.foo1(x) + 2 def foo3(self, x): return self.foo2(x) + 3 )JIT"); std::stringstream ss; m._save_for_mobile(ss); mobile::Module bc = _load_for_mobile(ss); std::vector inputs({torch::ones({})}); auto output = bc.run_method("foo3", inputs); AT_ASSERT(output.toTensor().item() == 7.0); } void testLiteInterpreterTuple() { Module m("m"); m.define(R"JIT( def foo(self, x): return (1, 2, x + 3) def forward(self, x): tuple = self.foo(x) return tuple )JIT"); std::stringstream ss; m._save_for_mobile(ss); mobile::Module bc = _load_for_mobile(ss); std::vector inputs({torch::ones({})}); auto output = bc.run_method("forward", inputs); AT_ASSERT(output.toTuple()->elements()[1].toInt() == 2); } void testLiteInterpreterPrimOverload() { /* // temporarily disabled script::Module m("m"); m.define(R"JIT( def forward(self, x): result = [1, 2] result.append(3) return result )JIT"); std::stringstream ss; m._save_for_mobile(ss); mobile::Module bc = _load_for_mobile(ss); std::vector inputs({torch::ones({})}); auto output = bc.run_method("forward", inputs); AT_ASSERT(output.toIntList()[2] == 3); */ } void testLiteInterpreterPrim() { Module m("m"); m.define(R"JIT( def forward(self, x): return int(x) )JIT"); std::vector inputs; auto minput = 3.5 * torch::ones({}); inputs.emplace_back(minput); auto ref = m.run_method("forward", minput); std::stringstream ss; m._save_for_mobile(ss); mobile::Module bc = _load_for_mobile(ss); IValue res; for (int i = 0; i < 3; ++i) { auto bcinputs = inputs; res = bc.run_method("forward", bcinputs); } auto resi = res.toInt(); auto refi = ref.toInt(); AT_ASSERT(resi == refi); } void testLiteInterpreterLoadOrigJit() { Module m("m"); m.register_parameter("foo", torch::ones({}), false); m.define(R"( def forward(self, x): b = 4 return self.foo + x + b )"); std::stringstream ss; m.save(ss); ASSERT_THROWS_WITH(_load_for_mobile(ss), "file not found"); } void testLiteInterpreterWrongMethodName() { Module m("m"); m.register_parameter("foo", torch::ones({}), false); m.define(R"( def add(self, x): b = 4 return self.foo + x + b )"); std::stringstream ss; m._save_for_mobile(ss); mobile::Module bc = _load_for_mobile(ss); std::vector inputs; auto minput = 5 * torch::ones({}); inputs.emplace_back(minput); ASSERT_THROWS_WITH(bc.run_method("forward", inputs), "is not defined"); } void testLiteInterpreterParams() { Module m("m"); m.register_parameter("foo", torch::ones({1}, at::requires_grad()), false); m.define(R"( def forward(self, x): b = 1.0 return self.foo * x + b )"); double learning_rate = 0.1, momentum = 0.1; int n_epoc = 10; // init: y = x + 1; // target: y = 2 x + 1 std::vector> trainData{ {1 * torch::ones({1}), 3 * torch::ones({1})}, }; // Reference: Full jit std::stringstream ms; m.save(ms); auto mm = load(ms); // mm.train(); std::vector<::at::Tensor> parameters; for (auto parameter : mm.parameters()) { parameters.emplace_back(parameter); } ::torch::optim::SGD optimizer( parameters, ::torch::optim::SGDOptions(learning_rate).momentum(momentum)); for (int epoc = 0; epoc < n_epoc; ++epoc) { for (auto &data : trainData) { auto source = data.first, targets = data.second; optimizer.zero_grad(); std::vector train_inputs{source}; auto output = mm.forward(train_inputs).toTensor(); auto loss = ::torch::l1_loss(output, targets); loss.backward(); optimizer.step(); } } std::stringstream ss; m._save_for_mobile(ss); mobile::Module bc = _load_for_mobile(ss); std::vector<::at::Tensor> bc_parameters = bc.parameters(); ::torch::optim::SGD bc_optimizer( bc_parameters, ::torch::optim::SGDOptions(learning_rate).momentum(momentum)); for (int epoc = 0; epoc < n_epoc; ++epoc) { for (auto &data : trainData) { auto source = data.first, targets = data.second; bc_optimizer.zero_grad(); std::vector train_inputs{source}; auto output = bc.forward(train_inputs).toTensor(); auto loss = ::torch::l1_loss(output, targets); loss.backward(); bc_optimizer.step(); } } AT_ASSERT(parameters[0].item() == bc_parameters[0].item()); } void testLiteInterpreterSetState() { Module m("m"); m.register_parameter("foo", torch::ones({}), false); m.define(R"( def __getstate__(self): return self.foo + self.foo def __setstate__(self, a): self.foo = a def forward(self, x): b = 4 return self.foo + x + b )"); std::vector inputs; auto minput = 5 * torch::ones({}); inputs.emplace_back(minput); std::stringstream ms; m.save(ms); auto loaded_m = load(ms); auto ref = loaded_m.run_method("forward", minput); std::stringstream ss; m._save_for_mobile(ss); mobile::Module bc = _load_for_mobile(ss); IValue res; for (int i = 0; i < 3; ++i) { auto bcinputs = inputs; res = bc.run_method("forward", bcinputs); } auto resd = res.toTensor().item(); auto refd = ref.toTensor().item(); AT_ASSERT(resd == refd); } class TorchBindLiteInterpreterTestStruct : public torch::jit::CustomClassHolder { public: std::string get(at::Tensor t) { std::stringstream ss; ss << "Hello! Your tensor has "; ss << t.numel(); ss << " elements!"; return ss.str(); } }; void testLiteInterpreterBuiltinFunction() { script::Module m("m"); auto custom_class_obj = make_custom_class(); m.register_attribute("my_obj", custom_class_obj.type(), custom_class_obj); m.define(R"( def forward(self, x) -> str: return self.my_obj.get(x) )"); std::stringstream ss; m._save_for_mobile(ss); mobile::Module bc = _load_for_mobile(ss); auto res = bc.run_method("forward", std::vector{torch::zeros({3, 4})}); auto str = res.toStringRef(); std::string expected = "Hello! Your tensor has 12 elements!"; AT_ASSERT(str == expected); } namespace { static auto reg = torch::jit::class_( "_TorchScriptTesting", "_LiteInterpreterTest") .def("get", &TorchBindLiteInterpreterTestStruct::get) .def_pickle( // __getattr__ [](const c10::intrusive_ptr& self) -> int64_t { return 0; }, // __setattr__ [](int64_t state) { return c10::make_intrusive(); }); } // namespace } // namespace jit } // namespace torch