From 21b697b64674362f523c10b0bb4c21c3038b7f13 Mon Sep 17 00:00:00 2001
From: Han Qi
Date: Wed, 19 Jan 2022 20:47:09 -0800
Subject: [PATCH] add flatbuffer_loader and flatbuffer_serializer as BUCK target (#71463)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/71463

title

Test Plan: unittest

Reviewed By: zhxchen17

Differential Revision: D33651339

fbshipit-source-id: 4bf325a40e263a441fd86bce560645ad0c1ebb23
(cherry picked from commit 4cb02e62a68f338e3388ad09276ced9b8f4cdcb1)
---
 test/cpp/jit/test_flatbuffer.cpp | 126 +++++++++++++++----------------
 1 file changed, 63 insertions(+), 63 deletions(-)

diff --git a/test/cpp/jit/test_flatbuffer.cpp b/test/cpp/jit/test_flatbuffer.cpp
index 73a84297f44..13533fa12be 100644
--- a/test/cpp/jit/test_flatbuffer.cpp
+++ b/test/cpp/jit/test_flatbuffer.cpp
@@ -772,7 +772,7 @@ void testLiteModuleCompareResultTensors(
   AT_ASSERT(output.equal(outputref));
 }
 
-void testDefaultArgsPinv(int num_args) {
+static void testDefaultArgsPinv(int num_args) {
   Module m("m");
   if (num_args == 1) {
     m.define(R"(
@@ -799,68 +799,6 @@ void testDefaultArgsPinv(int num_args) {
   inputs.emplace_back(input);
   testLiteModuleCompareResultTensors(m, inputs);
 }
-
-void testDefaultArgsPinvWithOutArg(int num_args) {
-  Module m("m");
-  if (num_args == 1) {
-    m.define(R"(
-      def forward(self, input):
-        return torch.linalg_pinv(input, out=input)
-    )");
-  } else if (num_args == 2) {
-    m.define(R"(
-      def forward(self, input):
-        return torch.linalg_pinv(input, 1e-5, out=input)
-    )");
-  } else if (num_args == 3) {
-    m.define(R"(
-      def forward(self, input):
-        return torch.linalg_pinv(input, 1e-5, True, out=input)
-    )");
-  }
-
-  const int N = 28;
-  auto input = torch::range(1, N * N, 1);
-  input[0] = 10000; // a more stable matrix
-  input = input.view({N, N});
-  auto ref = m.run_method("forward", input);
-  TORCH_CHECK(!input.equal(torch::range(1, N * N, 1)));
-  TORCH_CHECK(input.equal(ref.toTensor()));
-}
-
-TEST(FlatbufferTest, DefaultArgsPinvWithOutArg) {
-  // Test with different number of specified arguments + out arg.
-  // Arguments not specified take default value.
-  for (int num_args = 1; num_args <= 3; ++num_args) {
-    testDefaultArgsPinvWithOutArg(num_args);
-  }
-}
-
-TEST(FlatbufferTest, DefaultArgsWithOutArg) {
-  Module m("m");
-  m.define(R"(
-    def forward(self, x, h):
-        torch.add(x, h, out=x)
-  )");
-
-  std::vector<IValue> inputs;
-  auto input_x = 2 * torch::ones({});
-  auto input_h = torch::ones({});
-  auto ref = m.run_method("forward", input_x, input_h);
-
-  CompilationOptions options;
-  mobile::Module bc = jitModuleToMobile(m, options);
-  bc.run_method("forward", input_x, input_h);
-  AT_ASSERT(input_x.equal(4 * torch::ones({})));
-
-  auto buff = save_mobile_module_to_bytes(bc);
-  mobile::Module bc2 = parse_mobile_module(buff.data(), buff.size());
-  auto input_x2 = 2 * torch::ones({});
-  auto input_h2 = torch::ones({});
-  m.run_method("forward", input_x2, input_h2);
-  bc2.run_method("forward", input_x2, input_h2);
-  AT_ASSERT(input_x2.equal(4 * torch::ones({})));
-}
 } // namespace
 
 #if !defined FB_XPLAT_BUILD
@@ -962,6 +900,68 @@ TEST(FlatbufferTest, DefaultArgsTensorinvSpecifyDefault) {
   testLiteModuleCompareResultTensors(m, inputs);
 }
 
+static void testDefaultArgsPinvWithOutArg(int num_args) {
+  Module m("m");
+  if (num_args == 1) {
+    m.define(R"(
+      def forward(self, input):
+        return torch.linalg_pinv(input, out=input)
+    )");
+  } else if (num_args == 2) {
+    m.define(R"(
+      def forward(self, input):
+        return torch.linalg_pinv(input, 1e-5, out=input)
+    )");
+  } else if (num_args == 3) {
+    m.define(R"(
+      def forward(self, input):
+        return torch.linalg_pinv(input, 1e-5, True, out=input)
+    )");
+  }
+
+  const int N = 28;
+  auto input = torch::range(1, N * N, 1);
+  input[0] = 10000; // a more stable matrix
+  input = input.view({N, N});
+  auto ref = m.run_method("forward", input);
+  TORCH_CHECK(!input.equal(torch::range(1, N * N, 1)));
+  TORCH_CHECK(input.equal(ref.toTensor()));
+}
+
+TEST(FlatbufferTest, DefaultArgsPinvWithOutArg) {
+  // Test with different number of specified arguments + out arg.
+  // Arguments not specified take default value.
+  for (int num_args = 1; num_args <= 3; ++num_args) {
+    testDefaultArgsPinvWithOutArg(num_args);
+  }
+}
+
+TEST(FlatbufferTest, DefaultArgsWithOutArg) {
+  Module m("m");
+  m.define(R"(
+    def forward(self, x, h):
+        torch.add(x, h, out=x)
+  )");
+
+  std::vector<IValue> inputs;
+  auto input_x = 2 * torch::ones({});
+  auto input_h = torch::ones({});
+  auto ref = m.run_method("forward", input_x, input_h);
+
+  CompilationOptions options;
+  mobile::Module bc = jitModuleToMobile(m, options);
+  bc.run_method("forward", input_x, input_h);
+  AT_ASSERT(input_x.equal(4 * torch::ones({})));
+
+  auto buff = save_mobile_module_to_bytes(bc);
+  mobile::Module bc2 = parse_mobile_module(buff.data(), buff.size());
+  auto input_x2 = 2 * torch::ones({});
+  auto input_h2 = torch::ones({});
+  m.run_method("forward", input_x2, input_h2);
+  bc2.run_method("forward", input_x2, input_h2);
+  AT_ASSERT(input_x2.equal(4 * torch::ones({})));
+}
+
 #endif // !defined(FB_XPLAT_BUILD)
 
 namespace {