[sigmoid] Support custom obj deserialization. (#133463)

Summary:
It seems we have multiple places deserializing torchbind objects. This change moves the code around so that every load path essentially shares the same implementation.

Also added a test case, "package_reader_testing", which loads the archive file back in Python and eagerly validates the numerical results.
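
A minimal C++ sketch of the round trip through the shared implementation is shown below (assuming `obj` wraps a torchbind object whose class registers pickle support via `def_pickle`); the test itself goes through the equivalent `_pickle_save`/`_pickle_load_obj` Python bindings added in this change.

#include <torch/csrc/jit/serialization/pickle.h>

#include <string_view>
#include <vector>

// Sketch only: `obj` is assumed to hold a torchbind custom object whose class
// was registered with def_pickle (__getstate__/__setstate__).
c10::IValue round_trip(const c10::IValue& obj) {
  // Serialize to an archive; these are the same bytes the _pickle_save
  // binding returns to Python.
  std::vector<char> blob = torch::jit::pickle_save(obj);
  // Deserialize through the shared custom-object load path.
  return torch::jit::pickle_load_obj(
      std::string_view(blob.data(), blob.size()));
}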

Test Plan: buck test mode/opt sigmoid/inference/test:e2e_test_cpu

Reviewed By: SherlockNoMad

Differential Revision: D61235770

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133463
Approved by: https://github.com/ydwu4
Zhengxu Chen 2024-08-15 17:58:44 +00:00 committed by PyTorch MergeBot
parent 5ec9c0bc4a
commit 59b3f5911d
4 changed files with 86 additions and 1 deletion

View File

@@ -2409,7 +2409,8 @@ def _save_jit_module_to_bytes(m: ScriptModule, extra_files: Dict[str, Any]) ->
def _get_module_info_from_flatbuffer(data: bytes): ...
def _jit_resolve_packet(op_name: str, *args, **kwargs) -> str: ...
def _swap_tensor_impl(t1: Tensor, t2: Tensor): ...
def _save_pickle(obj: Any) -> bytes: ...
def _pickle_save(obj: Any) -> bytes: ...
def _pickle_load_obj(bs: bytes) -> Any: ...
# Defined in torch/csrc/jit/runtime/static/init.cpp
def _jit_to_static_module(graph_or_module: Union[Graph,ScriptModule]) -> Any: ...

View File

@@ -2443,6 +2443,12 @@ void initJitScriptBindings(PyObject* module) {
    return py::bytes(bytes.data(), bytes.size());
  });
  m.def("_pickle_load_obj", [](const py::bytes& bytes) {
    // Copy the py::bytes into an owned std::string before use; see
    // https://github.com/pybind/pybind11/issues/2517
    std::string buffer = bytes;
    return torch::jit::pickle_load_obj(buffer);
  });
  initScriptDictBindings(module);
  initScriptListBindings(module);
}

View File

@@ -9,6 +9,35 @@
namespace torch::jit {

namespace {

c10::StrongTypePtr customClassResolver(const c10::QualifiedName& qn) {
  at::TypePtr type = nullptr;
  if (c10::QualifiedName("__torch__").isPrefixOf(qn)) {
    // Torchbind custom classes live under the "__torch__" namespace; look
    // them up in the global custom class registry.
    type = torch::getCustomClass(qn.qualifiedName());
  } else {
    // This is a regular type, fall back to the default type parser.
    torch::jit::ScriptTypeParser parser;
    type = parser.parseType(qn.qualifiedName());
    return c10::StrongTypePtr(nullptr, std::move(type));
  }
  if (type == nullptr) {
    TORCH_CHECK(
        false,
        "Couldn't resolve type '",
        qn.qualifiedName(),
        "', did you forget to add its build dependency?");
  }
  // Passing a nullptr CompilationUnit is a little bit sus, but should be fine:
  // 1. The lifetime of the class type is not tied to a specific
  //    CompilationUnit, but rather to the global custom class registry.
  // 2. We will not access the `cu_` field and immediately discard this
  //    StrongTypePtr post-deserialization.
  return c10::StrongTypePtr(nullptr, std::move(type));
}

} // namespace

void pickle(
    std::function<void(const char* data_start, size_t data_len)> writer,
    const IValue& ivalue,
@@ -80,6 +109,16 @@ size_t VectorReader::read(uint64_t pos, void* buf, size_t n, const char* what)
      data_.data() + pos, data_.data() + pos + n, reinterpret_cast<char*>(buf));
  return n;
}

size_t StringViewReader::read(
    uint64_t pos,
    void* buf,
    size_t n,
    const char* what) const {
  std::copy(
      data_.data() + pos, data_.data() + pos + n, reinterpret_cast<char*>(buf));
  return n;
}
#endif

IValue pickle_load(const std::vector<char>& data) {
@@ -103,6 +142,26 @@ IValue pickle_load(const std::vector<char>& data) {
#endif
};

// A specialized version of pickle_load that can load custom objects.
c10::IValue pickle_load_obj(std::string_view data) {
#ifndef C10_MOBILE
  caffe2::serialize::PyTorchStreamReader reader(
      std::make_unique<torch::jit::StringViewReader>(data));
  return torch::jit::readArchiveAndTensors(
      "data",
      /*pickle_prefix=*/"",
      /*tensor_prefix=*/"",
      /*type_resolver=*/customClassResolver,
      /*obj_loader=*/torch::jit::ObjLoaderFunc,
      /*device=*/c10::nullopt,
      reader);
#else
  AT_ERROR(
      "pickle_load_obj not supported on mobile "
      "(see https://github.com/pytorch/pytorch/pull/30108)");
#endif
}

IValue unpickle(
    std::function<size_t(char*, size_t)> reader,
    TypeResolver type_resolver,
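
For context, customClassResolver and the object loader above resolve torchbind classes registered through the standard torch::class_ API. The sketch below uses a hypothetical class and library name; the def_pickle hooks are what allow such an object's state to be written by pickle_save and rebuilt during deserialization.

#include <torch/custom_class.h>

#include <string>
#include <utility>
#include <vector>

// Hypothetical torchbind class, registered under
// "__torch__.torch.classes.my_classes.MyStack", the kind of qualified name
// customClassResolver hands to torch::getCustomClass.
struct MyStack : torch::CustomClassHolder {
  explicit MyStack(std::vector<std::string> init) : stack_(std::move(init)) {}
  void push(std::string x) {
    stack_.push_back(std::move(x));
  }
  std::string pop() {
    auto ret = stack_.back();
    stack_.pop_back();
    return ret;
  }
  std::vector<std::string> stack_;
};

TORCH_LIBRARY(my_classes, m) {
  m.class_<MyStack>("MyStack")
      .def(torch::init<std::vector<std::string>>())
      .def("push", &MyStack::push)
      .def("pop", &MyStack::pop)
      // __getstate__ / __setstate__: the pickler writes the returned state and
      // the object loader reconstructs the instance from it on load.
      .def_pickle(
          [](const c10::intrusive_ptr<MyStack>& self) { return self->stack_; },
          [](std::vector<std::string> state) {
            return c10::make_intrusive<MyStack>(std::move(state));
          });
}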

View File

@@ -60,6 +60,10 @@ TORCH_API std::vector<char> pickle_save(const IValue& ivalue);
/// `torch::pickle_save` in C++ or `torch.save` in Python
TORCH_API IValue pickle_load(const std::vector<char>& data);

/// Deserialize a `torch::IValue` from bytes produced by either
/// `torch::pickle_save` in C++ or `torch.save` in Python, where the payload
/// may contain custom (torchbind) objects.
TORCH_API IValue pickle_load_obj(std::string_view data);

/// `reader` is a function that takes in a size to read from some pickled
/// binary. `reader` should remember where it last read, and return
/// the number of bytes read.
@@ -117,5 +121,20 @@ class VectorReader : public caffe2::serialize::ReadAdapterInterface {
 private:
  std::vector<char> data_;
};

// A ReadAdapterInterface backed by a non-owning std::string_view; the caller
// must keep the underlying buffer alive while the reader is in use.
class StringViewReader : public caffe2::serialize::ReadAdapterInterface {
 public:
  StringViewReader(std::string_view data) : data_(data) {}
  size_t size() const override {
    return data_.size();
  }
  size_t read(uint64_t pos, void* buf, size_t n, const char* what)
      const override;

 private:
  std::string_view data_;
};
#endif
} // namespace torch::jit
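
Since pickle_load_obj takes a std::string_view, a caller that already holds the archive bytes in memory can pass them without first copying into a std::vector<char>. A minimal illustrative sketch (the helper name and file-reading code are not part of this change):

#include <torch/csrc/jit/serialization/pickle.h>

#include <fstream>
#include <iterator>
#include <string>

// Illustrative helper: read a serialized archive from disk and deserialize it,
// including any torchbind objects it contains.
c10::IValue load_obj_from_file(const std::string& path) {
  std::ifstream in(path, std::ios::binary);
  std::string blob(
      (std::istreambuf_iterator<char>(in)), std::istreambuf_iterator<char>());
  // StringViewReader only references `blob`, so it must stay alive until the
  // call returns.
  return torch::jit::pickle_load_obj(blob);
}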