mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/28828 This updates torch::script::Module to more closely match the behavior of nn.Module. In particular, it implements the (optionally recursive) iterators that retrieve submodules, parameters, and buffers and makes their names match the python versions. This also removes the individual accessors for Parameter, Module, Buffer, etc. and replaces them with a single `attr` function which is equivalent to writing `a.foo` in Python (`setattr` emulates `a.foo = v`). As we build out the user-facing API for TorchScript values this will end up matching how an attribute is accessed on general objects. This PR preserves the python bindings for script::Module by emulating the old API at the binding level. A followup will clean up the usage to more directly match the C++ API. Test Plan: Imported from OSS Differential Revision: D18197611 Pulled By: zdevito fbshipit-source-id: 7ee4dcbb258605d1c988314b05d938423f1ccee5
82 lines
2.2 KiB
C++
82 lines
2.2 KiB
C++
|
|
#include <test/cpp/jit/test_base.h>
|
|
#include <test/cpp/jit/test_utils.h>
|
|
|
|
#include <ATen/core/qualified_name.h>
|
|
#include <torch/csrc/jit/import.h>
|
|
#include <torch/csrc/jit/import_source.h>
|
|
#include <torch/csrc/jit/script/resolver.h>
|
|
#include <torch/torch.h>
|
|
|
|
namespace torch {
|
|
namespace jit {
|
|
|
|
using namespace torch::jit::script;
|
|
|
|
// TorchScript source for the methods defined on `subMod`. `subMod` is later
// registered on the parent module under the `OneForward` interface type, so
// these are expected to provide the methods that interface declares.
// Indentation inside the raw string is significant: TorchScript uses
// Python's indentation-based block structure.
static const std::vector<std::string> subMethodSrcs = {R"JIT(
def one(self, x: Tensor, y: Tensor) -> Tensor:
    return x + y + 1

def forward(self, x: Tensor) -> Tensor:
    return x
)JIT"};
|
|
// TorchScript `forward` for the parent module: it delegates to the `subMod`
// attribute. Indentation inside the raw string is significant.
static const auto parentForward = R"JIT(
def forward(self, x: Tensor) -> Tensor:
    return self.subMod.forward(x)
)JIT";
|
|
|
|
// TorchScript declaration of the `OneForward` module interface. The method
// bodies are just `pass`; only the signatures matter. The methods must be
// indented under the class and `pass` under each method, because TorchScript
// uses Python's indentation-based block structure.
static const auto moduleInterfaceSrc = R"JIT(
class OneForward(ModuleInterface):
    def one(self, x: Tensor, y: Tensor) -> Tensor:
        pass
    def forward(self, x: Tensor) -> Tensor:
        pass
)JIT";
|
|
|
|
static void import_libs(
|
|
std::shared_ptr<CompilationUnit> cu,
|
|
const std::string& class_name,
|
|
const std::shared_ptr<Source>& src,
|
|
const std::vector<at::Tensor>& tensor_table) {
|
|
SourceImporter si(
|
|
cu,
|
|
&tensor_table,
|
|
[&](const std::string& name) -> std::shared_ptr<Source> { return src; },
|
|
/*version=*/2);
|
|
si.loadNamedType(QualifiedName(class_name));
|
|
}
|
|
|
|
void testModuleInterfaceSerialization() {
|
|
auto cu = std::make_shared<CompilationUnit>();
|
|
Module parentMod("parentMod", cu);
|
|
Module subMod("subMod", cu);
|
|
|
|
std::vector<at::Tensor> constantTable;
|
|
import_libs(
|
|
cu,
|
|
"__torch__.OneForward",
|
|
std::make_shared<Source>(moduleInterfaceSrc),
|
|
constantTable);
|
|
|
|
for (const std::string& method : subMethodSrcs) {
|
|
subMod.define(method, nativeResolver());
|
|
}
|
|
parentMod.register_attribute(
|
|
"subMod",
|
|
cu->get_interface("__torch__.OneForward"),
|
|
subMod.module_object(),
|
|
/*is_parameter=*/false);
|
|
parentMod.define(parentForward, nativeResolver());
|
|
ASSERT_TRUE(parentMod.hasattr("subMod"));
|
|
std::stringstream ss;
|
|
parentMod.save(ss);
|
|
Module reloaded_mod = jit::load(ss);
|
|
ASSERT_TRUE(reloaded_mod.hasattr("subMod"));
|
|
InterfaceTypePtr submodType =
|
|
reloaded_mod.type()->getAttribute("subMod")->cast<InterfaceType>();
|
|
ASSERT_TRUE(submodType->is_module());
|
|
}
|
|
|
|
} // namespace jit
|
|
} // namespace torch
|