mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/37251 This was broken by recent changes to how we serialize with type tags. We saved a type name (like `Dict[str, MyNamedTuple]`) and then relied on the mobile type parser to resolve that name back into a set of types. This doesn't work for any NamedTypes, as the mobile type parser doesn't know how to resolve those. The unpickler allows the caller to inject a type resolver for this purpose; use that so that importing in a non-mobile environment produces the right results. A second problem also had to be fixed: the SourceImporter type loader would only load named types directly (e.g. `MyNamedTuple`) and choked on a general type that contained a named tuple (e.g. `List[MyNamedTuple]`). Fixed that and renamed `loadNamedType` to `loadType` for clarity. Test Plan: Imported from OSS Differential Revision: D21235213 Pulled By: suo fbshipit-source-id: 16db0f4c5e91a890d67a8687cc8ababa6b94b0f4
80 lines
2.2 KiB
C++
80 lines
2.2 KiB
C++
|
|
#include <test/cpp/jit/test_base.h>
|
|
#include <test/cpp/jit/test_utils.h>
|
|
|
|
#include <ATen/core/qualified_name.h>
|
|
#include <torch/csrc/jit/frontend/resolver.h>
|
|
#include <torch/csrc/jit/serialization/import.h>
|
|
#include <torch/csrc/jit/serialization/import_source.h>
|
|
#include <torch/torch.h>
|
|
|
|
namespace torch {
|
|
namespace jit {
|
|
|
|
// TorchScript source for the submodule's methods. Each entry is compiled with
// Module::define(), so the text inside the raw string literal must be valid
// TorchScript: method bodies are indentation-sensitive and must not contain
// stray characters.
// `one` and `forward` mirror the signatures declared by the OneForward
// interface, since the test stores this submodule in an attribute of that
// interface type.
static const std::vector<std::string> subMethodSrcs = {R"JIT(
def one(self, x: Tensor, y: Tensor) -> Tensor:
    return x + y + 1

def forward(self, x: Tensor) -> Tensor:
    return x
)JIT"};
|
|
// TorchScript `forward` for the parent module. It delegates to the attribute
// `subMod`, which the test registers with the OneForward interface type.
// Compiled via Module::define(), so the literal must be correctly indented
// TorchScript.
static const auto parentForward = R"JIT(
def forward(self, x: Tensor) -> Tensor:
    return self.subMod.forward(x)
)JIT";
|
|
|
|
// TorchScript declaration of the module interface `OneForward`. The test
// loads it through import_libs() under the qualified name
// `__torch__.OneForward`. The methods are nested under the class by
// indentation, so the layout inside the raw string literal is significant.
static const auto moduleInterfaceSrc = R"JIT(
class OneForward(ModuleInterface):
    def one(self, x: Tensor, y: Tensor) -> Tensor:
        pass
    def forward(self, x: Tensor) -> Tensor:
        pass
)JIT";
|
|
|
|
static void import_libs(
|
|
std::shared_ptr<CompilationUnit> cu,
|
|
const std::string& class_name,
|
|
const std::shared_ptr<Source>& src,
|
|
const std::vector<at::Tensor>& tensor_table) {
|
|
SourceImporter si(
|
|
cu,
|
|
&tensor_table,
|
|
[&](const std::string& name) -> std::shared_ptr<Source> { return src; },
|
|
/*version=*/2);
|
|
si.loadType(QualifiedName(class_name));
|
|
}
|
|
|
|
void testModuleInterfaceSerialization() {
|
|
auto cu = std::make_shared<CompilationUnit>();
|
|
Module parentMod("parentMod", cu);
|
|
Module subMod("subMod", cu);
|
|
|
|
std::vector<at::Tensor> constantTable;
|
|
import_libs(
|
|
cu,
|
|
"__torch__.OneForward",
|
|
std::make_shared<Source>(moduleInterfaceSrc),
|
|
constantTable);
|
|
|
|
for (const std::string& method : subMethodSrcs) {
|
|
subMod.define(method, nativeResolver());
|
|
}
|
|
parentMod.register_attribute(
|
|
"subMod",
|
|
cu->get_interface("__torch__.OneForward"),
|
|
subMod._ivalue(),
|
|
/*is_parameter=*/false);
|
|
parentMod.define(parentForward, nativeResolver());
|
|
ASSERT_TRUE(parentMod.hasattr("subMod"));
|
|
std::stringstream ss;
|
|
parentMod.save(ss);
|
|
Module reloaded_mod = jit::load(ss);
|
|
ASSERT_TRUE(reloaded_mod.hasattr("subMod"));
|
|
InterfaceTypePtr submodType =
|
|
reloaded_mod.type()->getAttribute("subMod")->cast<InterfaceType>();
|
|
ASSERT_TRUE(submodType->is_module());
|
|
}
|
|
|
|
} // namespace jit
|
|
} // namespace torch
|