add flatbuffer_loader and flatbuffer_serializer as BUCK target (#71463)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/71463

Test Plan: unittest

Reviewed By: zhxchen17

Differential Revision: D33651339

fbshipit-source-id: 4bf325a40e263a441fd86bce560645ad0c1ebb23
parent d886380ede
commit 4cb02e62a6
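Only the test-file half of the change is shown in the diff below; the BUCK target definitions themselves are not part of this view. As a rough sketch of what such targets could look like, using Buck's cxx_library rule, here is a hypothetical version; every name, path, and attribute is an assumption for illustration, not copied from the actual build files:

# Hypothetical sketch of the new BUCK targets.
# Target names, source paths, and attributes are assumptions.
cxx_library(
    name = "flatbuffer_loader",
    srcs = ["torch/csrc/jit/mobile/flatbuffer_loader.cpp"],
    exported_headers = ["torch/csrc/jit/mobile/flatbuffer_loader.h"],
    visibility = ["PUBLIC"],
)

cxx_library(
    name = "flatbuffer_serializer",
    srcs = ["torch/csrc/jit/serialization/flatbuffer_serializer.cpp"],
    exported_headers = ["torch/csrc/jit/serialization/flatbuffer_serializer.h"],
    visibility = ["PUBLIC"],
)

Keeping loader and serializer as separate targets presumably lets deployment builds that only load modules avoid linking the serializer.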
@@ -772,7 +772,7 @@ void testLiteModuleCompareResultTensors(
   AT_ASSERT(output.equal(outputref));
 }
 
-void testDefaultArgsPinv(int num_args) {
+static void testDefaultArgsPinv(int num_args) {
   Module m("m");
   if (num_args == 1) {
     m.define(R"(
@@ -799,68 +799,6 @@ void testDefaultArgsPinv(int num_args) {
   inputs.emplace_back(input);
   testLiteModuleCompareResultTensors(m, inputs);
 }
-
-void testDefaultArgsPinvWithOutArg(int num_args) {
-  Module m("m");
-  if (num_args == 1) {
-    m.define(R"(
-      def forward(self, input):
-        return torch.linalg_pinv(input, out=input)
-    )");
-  } else if (num_args == 2) {
-    m.define(R"(
-      def forward(self, input):
-        return torch.linalg_pinv(input, 1e-5, out=input)
-    )");
-  } else if (num_args == 3) {
-    m.define(R"(
-      def forward(self, input):
-        return torch.linalg_pinv(input, 1e-5, True, out=input)
-    )");
-  }
-
-  const int N = 28;
-  auto input = torch::range(1, N * N, 1);
-  input[0] = 10000; // a more stable matrix
-  input = input.view({N, N});
-  auto ref = m.run_method("forward", input);
-  TORCH_CHECK(!input.equal(torch::range(1, N * N, 1)));
-  TORCH_CHECK(input.equal(ref.toTensor()));
-}
-
-TEST(FlatbufferTest, DefaultArgsPinvWithOutArg) {
-  // Test with different number of specified arguments + out arg.
-  // Arguments not specified take default value.
-  for (int num_args = 1; num_args <= 3; ++num_args) {
-    testDefaultArgsPinvWithOutArg(num_args);
-  }
-}
-
-TEST(FlatbufferTest, DefaultArgsWithOutArg) {
-  Module m("m");
-  m.define(R"(
-    def forward(self, x, h):
-      torch.add(x, h, out=x)
-  )");
-
-  std::vector<IValue> inputs;
-  auto input_x = 2 * torch::ones({});
-  auto input_h = torch::ones({});
-  auto ref = m.run_method("forward", input_x, input_h);
-
-  CompilationOptions options;
-  mobile::Module bc = jitModuleToMobile(m, options);
-  bc.run_method("forward", input_x, input_h);
-  AT_ASSERT(input_x.equal(4 * torch::ones({})));
-
-  auto buff = save_mobile_module_to_bytes(bc);
-  mobile::Module bc2 = parse_mobile_module(buff.data(), buff.size());
-  auto input_x2 = 2 * torch::ones({});
-  auto input_h2 = torch::ones({});
-  m.run_method("forward", input_x2, input_h2);
-  bc2.run_method("forward", input_x2, input_h2);
-  AT_ASSERT(input_x2.equal(4 * torch::ones({})));
-}
 } // namespace
 
 #if !defined FB_XPLAT_BUILD
@@ -962,6 +900,68 @@ TEST(FlatbufferTest, DefaultArgsTensorinvSpecifyDefault) {
   testLiteModuleCompareResultTensors(m, inputs);
 }
+
+static void testDefaultArgsPinvWithOutArg(int num_args) {
+  Module m("m");
+  if (num_args == 1) {
+    m.define(R"(
+      def forward(self, input):
+        return torch.linalg_pinv(input, out=input)
+    )");
+  } else if (num_args == 2) {
+    m.define(R"(
+      def forward(self, input):
+        return torch.linalg_pinv(input, 1e-5, out=input)
+    )");
+  } else if (num_args == 3) {
+    m.define(R"(
+      def forward(self, input):
+        return torch.linalg_pinv(input, 1e-5, True, out=input)
+    )");
+  }
+
+  const int N = 28;
+  auto input = torch::range(1, N * N, 1);
+  input[0] = 10000; // a more stable matrix
+  input = input.view({N, N});
+  auto ref = m.run_method("forward", input);
+  TORCH_CHECK(!input.equal(torch::range(1, N * N, 1)));
+  TORCH_CHECK(input.equal(ref.toTensor()));
+}
+
+TEST(FlatbufferTest, DefaultArgsPinvWithOutArg) {
+  // Test with different number of specified arguments + out arg.
+  // Arguments not specified take default value.
+  for (int num_args = 1; num_args <= 3; ++num_args) {
+    testDefaultArgsPinvWithOutArg(num_args);
+  }
+}
+
+TEST(FlatbufferTest, DefaultArgsWithOutArg) {
+  Module m("m");
+  m.define(R"(
+    def forward(self, x, h):
+      torch.add(x, h, out=x)
+  )");
+
+  std::vector<IValue> inputs;
+  auto input_x = 2 * torch::ones({});
+  auto input_h = torch::ones({});
+  auto ref = m.run_method("forward", input_x, input_h);
+
+  CompilationOptions options;
+  mobile::Module bc = jitModuleToMobile(m, options);
+  bc.run_method("forward", input_x, input_h);
+  AT_ASSERT(input_x.equal(4 * torch::ones({})));
+
+  auto buff = save_mobile_module_to_bytes(bc);
+  mobile::Module bc2 = parse_mobile_module(buff.data(), buff.size());
+  auto input_x2 = 2 * torch::ones({});
+  auto input_h2 = torch::ones({});
+  m.run_method("forward", input_x2, input_h2);
+  bc2.run_method("forward", input_x2, input_h2);
+  AT_ASSERT(input_x2.equal(4 * torch::ones({})));
+}
 
 #endif // !defined(FB_XPLAT_BUILD)
 
 namespace {
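Taken together, the hunks above move the two helpers and their TESTs out of the anonymous namespace and into the region guarded by #if !defined FB_XPLAT_BUILD, marking them static to preserve internal linkage; the apparent effect is that these tests are skipped in FB xplat (BUCK) builds while still running in the OSS unittest. For illustration only, a Buck cxx_test target wiring this test file against the two new libraries might look like the sketch below; the target name, file path, and dep labels are all hypothetical:

# Hypothetical sketch -- name, srcs, and deps are assumptions.
cxx_test(
    name = "test_flatbuffer",
    srcs = ["test_flatbuffer.cpp"],
    deps = [
        ":flatbuffer_loader",
        ":flatbuffer_serializer",
    ],
)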