add flatbuffer_loader and flatbuffer_serializer as BUCK target (#71463)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/71463

title

Test Plan: unittest

Reviewed By: zhxchen17

Differential Revision: D33651339

fbshipit-source-id: 4bf325a40e263a441fd86bce560645ad0c1ebb23
This commit is contained in:
Han Qi 2022-01-19 20:47:09 -08:00 committed by Facebook GitHub Bot
parent d886380ede
commit 4cb02e62a6

View File

@ -772,7 +772,7 @@ void testLiteModuleCompareResultTensors(
AT_ASSERT(output.equal(outputref));
}
void testDefaultArgsPinv(int num_args) {
static void testDefaultArgsPinv(int num_args) {
Module m("m");
if (num_args == 1) {
m.define(R"(
@ -799,68 +799,6 @@ void testDefaultArgsPinv(int num_args) {
inputs.emplace_back(input);
testLiteModuleCompareResultTensors(m, inputs);
}
void testDefaultArgsPinvWithOutArg(int num_args) {
Module m("m");
if (num_args == 1) {
m.define(R"(
def forward(self, input):
return torch.linalg_pinv(input, out=input)
)");
} else if (num_args == 2) {
m.define(R"(
def forward(self, input):
return torch.linalg_pinv(input, 1e-5, out=input)
)");
} else if (num_args == 3) {
m.define(R"(
def forward(self, input):
return torch.linalg_pinv(input, 1e-5, True, out=input)
)");
}
const int N = 28;
auto input = torch::range(1, N * N, 1);
input[0] = 10000; // a more stable matrix
input = input.view({N, N});
auto ref = m.run_method("forward", input);
TORCH_CHECK(!input.equal(torch::range(1, N * N, 1)));
TORCH_CHECK(input.equal(ref.toTensor()));
}
TEST(FlatbufferTest, DefaultArgsPinvWithOutArg) {
// Test with different number of specified arguments + out arg.
// Arguments not specified take default value.
for (int num_args = 1; num_args <= 3; ++num_args) {
testDefaultArgsPinvWithOutArg(num_args);
}
}
TEST(FlatbufferTest, DefaultArgsWithOutArg) {
Module m("m");
m.define(R"(
def forward(self, x, h):
torch.add(x, h, out=x)
)");
std::vector<IValue> inputs;
auto input_x = 2 * torch::ones({});
auto input_h = torch::ones({});
auto ref = m.run_method("forward", input_x, input_h);
CompilationOptions options;
mobile::Module bc = jitModuleToMobile(m, options);
bc.run_method("forward", input_x, input_h);
AT_ASSERT(input_x.equal(4 * torch::ones({})));
auto buff = save_mobile_module_to_bytes(bc);
mobile::Module bc2 = parse_mobile_module(buff.data(), buff.size());
auto input_x2 = 2 * torch::ones({});
auto input_h2 = torch::ones({});
m.run_method("forward", input_x2, input_h2);
bc2.run_method("forward", input_x2, input_h2);
AT_ASSERT(input_x2.equal(4 * torch::ones({})));
}
} // namespace
#if !defined FB_XPLAT_BUILD
@ -962,6 +900,68 @@ TEST(FlatbufferTest, DefaultArgsTensorinvSpecifyDefault) {
testLiteModuleCompareResultTensors(m, inputs);
}
static void testDefaultArgsPinvWithOutArg(int num_args) {
Module m("m");
if (num_args == 1) {
m.define(R"(
def forward(self, input):
return torch.linalg_pinv(input, out=input)
)");
} else if (num_args == 2) {
m.define(R"(
def forward(self, input):
return torch.linalg_pinv(input, 1e-5, out=input)
)");
} else if (num_args == 3) {
m.define(R"(
def forward(self, input):
return torch.linalg_pinv(input, 1e-5, True, out=input)
)");
}
const int N = 28;
auto input = torch::range(1, N * N, 1);
input[0] = 10000; // a more stable matrix
input = input.view({N, N});
auto ref = m.run_method("forward", input);
TORCH_CHECK(!input.equal(torch::range(1, N * N, 1)));
TORCH_CHECK(input.equal(ref.toTensor()));
}
TEST(FlatbufferTest, DefaultArgsPinvWithOutArg) {
// Test with different number of specified arguments + out arg.
// Arguments not specified take default value.
for (int num_args = 1; num_args <= 3; ++num_args) {
testDefaultArgsPinvWithOutArg(num_args);
}
}
TEST(FlatbufferTest, DefaultArgsWithOutArg) {
Module m("m");
m.define(R"(
def forward(self, x, h):
torch.add(x, h, out=x)
)");
std::vector<IValue> inputs;
auto input_x = 2 * torch::ones({});
auto input_h = torch::ones({});
auto ref = m.run_method("forward", input_x, input_h);
CompilationOptions options;
mobile::Module bc = jitModuleToMobile(m, options);
bc.run_method("forward", input_x, input_h);
AT_ASSERT(input_x.equal(4 * torch::ones({})));
auto buff = save_mobile_module_to_bytes(bc);
mobile::Module bc2 = parse_mobile_module(buff.data(), buff.size());
auto input_x2 = 2 * torch::ones({});
auto input_h2 = torch::ones({});
m.run_method("forward", input_x2, input_h2);
bc2.run_method("forward", input_x2, input_h2);
AT_ASSERT(input_x2.equal(4 * torch::ones({})));
}
#endif // !defined(FB_XPLAT_BUILD)
namespace {