pytorch/test/cpp/jit/test_save_load.cpp
Zachary DeVito 0e3389dced Fix circular deps in loading (#26758)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/26758

This PR changes the order in which we import classes and functions so
that it is no longer necessary for them to be defined in order in a file,
or for there to be proper import statements in the exported file.

Actually importing a function/class now is driven by the need to resolve
the entity during unpickling, type resolution, or value resolution.

While this should allow significant simplification to the code that
serializes classes, this work has not been done yet in order to avoid
inevitable forward compat issues in the transition period.

Notes:
* Individual functions have been replaced with a SourceImporter object
  that exposes a resolveType method. This method loads the type if
  it has not been loaded yet, potentially parsing (but not loading)
  the file it exists in if that file hasn't been parsed yet.
* Some legacy functionality needed to be added as a method to this object
  since the old format still used some of this logic for class resolution.

Test Plan: Imported from OSS

Differential Revision: D17558989

Pulled By: zdevito

fbshipit-source-id: 7eae3470bcbd388c4de463e3462d527776ed46c6
2019-09-26 11:39:16 -07:00

78 lines
1.9 KiB
C++

#include <test/cpp/jit/test_base.h>
#include <test/cpp/jit/test_utils.h>
#include <sstream>
#include <torch/csrc/jit/export.h>
#include <torch/csrc/jit/import.h>
#include <torch/csrc/jit/import_source.h>
#include <torch/torch.h>
namespace torch {
namespace jit {
using namespace script;
void testSaveExtraFilesHook() {
// no secrets
{
std::stringstream ss;
{
Module m("__torch__.m");
ExtraFilesMap extra;
extra["metadata.json"] = "abc";
m.save(ss, extra);
}
ss.seekg(0);
{
ExtraFilesMap extra;
extra["metadata.json"] = "";
extra["secret.json"] = "";
jit::load(ss, c10::nullopt, extra);
ASSERT_EQ(extra["metadata.json"], "abc");
ASSERT_EQ(extra["secret.json"], "");
}
}
// some secret
{
std::stringstream ss;
{
SetExportModuleExtraFilesHook([](const Module&) -> ExtraFilesMap {
return {{"secret.json", "topsecret"}};
});
Module m("__torch__.m");
ExtraFilesMap extra;
extra["metadata.json"] = "abc";
m.save(ss, extra);
SetExportModuleExtraFilesHook(nullptr);
}
ss.seekg(0);
{
ExtraFilesMap extra;
extra["metadata.json"] = "";
extra["secret.json"] = "";
jit::load(ss, c10::nullopt, extra);
ASSERT_EQ(extra["metadata.json"], "abc");
ASSERT_EQ(extra["secret.json"], "topsecret");
}
}
}
// TorchScript source text consumed by testImportTooNew. It declares
// op_version_set = 1000, a version that test expects the importer to reject.
// The text is parsed verbatim, so its contents must not be reformatted.
static constexpr const char* pretty_printed = R"JIT(
op_version_set = 1000
def foo(x: Tensor,
y: Tensor) -> Tensor:
_0 = torch.add(torch.mul(x, 2), y, alpha=1)
return _0
)JIT";
void testImportTooNew() {
Module m("__torch__.m");
const std::vector<at::Tensor> constant_table;
auto src = std::make_shared<Source>(pretty_printed);
SourceImporter si(m.class_compilation_unit(), &constant_table, nullptr);
ASSERT_ANY_THROW(si.LEGACY_import_methods(m, src));
}
} // namespace jit
} // namespace torch