mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[sigmoid] Support custom obj deserialization. (#133463)
Summary: It seems we have multiple places deserializing torchbind objects. Moving the code around so that every load essentially share the same implementation. Also added a test case "package_reader_testing" which load back the archive file in Python and eagerly validate the numerical result. Test Plan: buck test mode/opt sigmoid/inference/test:e2e_test_cpu Reviewed By: SherlockNoMad Differential Revision: D61235770 Pull Request resolved: https://github.com/pytorch/pytorch/pull/133463 Approved by: https://github.com/ydwu4
This commit is contained in:
parent
5ec9c0bc4a
commit
59b3f5911d
|
|
@ -2409,7 +2409,8 @@ def _save_jit_module_to_bytes(m: ScriptModule, extra_files: Dict[str, Any]) ->
|
|||
def _get_module_info_from_flatbuffer(data: bytes): ...
|
||||
def _jit_resolve_packet(op_name: str, *args, **kwargs) -> str: ...
|
||||
def _swap_tensor_impl(t1: Tensor, t2: Tensor): ...
|
||||
def _save_pickle(obj: Any) -> bytes: ...
|
||||
def _pickle_save(obj: Any) -> bytes: ...
|
||||
def _pickle_load_obj(bs: bytes) -> Any: ...
|
||||
|
||||
# Defined in torch/csrc/jit/runtime/static/init.cpp
|
||||
def _jit_to_static_module(graph_or_module: Union[Graph,ScriptModule]) -> Any: ...
|
||||
|
|
|
|||
|
|
@ -2443,6 +2443,12 @@ void initJitScriptBindings(PyObject* module) {
|
|||
return py::bytes(bytes.data(), bytes.size());
|
||||
});
|
||||
|
||||
m.def("_pickle_load_obj", [](const py::bytes& bytes) {
|
||||
// https://github.com/pybind/pybind11/issues/2517
|
||||
std::string buffer = bytes;
|
||||
return torch::jit::pickle_load_obj(buffer);
|
||||
});
|
||||
|
||||
initScriptDictBindings(module);
|
||||
initScriptListBindings(module);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -9,6 +9,35 @@
|
|||
|
||||
namespace torch::jit {
|
||||
|
||||
namespace {
|
||||
|
||||
c10::StrongTypePtr customClassResolver(const c10::QualifiedName& qn) {
|
||||
at::TypePtr type = nullptr;
|
||||
if (c10::QualifiedName("__torch__").isPrefixOf(qn)) {
|
||||
type = torch::getCustomClass(qn.qualifiedName());
|
||||
} else {
|
||||
// This is a regular type, fall back to the default type parser
|
||||
torch::jit::ScriptTypeParser parser;
|
||||
type = parser.parseType(qn.qualifiedName());
|
||||
return c10::StrongTypePtr(nullptr, std::move(type));
|
||||
}
|
||||
if (type == nullptr) {
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
"Couldn't resolve type '{}', did you forget to add its build dependency?",
|
||||
qn.qualifiedName());
|
||||
}
|
||||
// Passing nullptr is a little bit sus, but should be fine:
|
||||
// 1. The lifetime of the class type is not tied to a specific
|
||||
// CompilationUnit
|
||||
// but rather the global custom class registry.
|
||||
// 2. We will not access the `cu_` field and immediately discard this
|
||||
// StrongTypePtr post-deserialization.
|
||||
return c10::StrongTypePtr(nullptr, std::move(type));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void pickle(
|
||||
std::function<void(const char* data_start, size_t data_len)> writer,
|
||||
const IValue& ivalue,
|
||||
|
|
@ -80,6 +109,16 @@ size_t VectorReader::read(uint64_t pos, void* buf, size_t n, const char* what)
|
|||
data_.data() + pos, data_.data() + pos + n, reinterpret_cast<char*>(buf));
|
||||
return n;
|
||||
}
|
||||
|
||||
size_t StringViewReader::read(
|
||||
uint64_t pos,
|
||||
void* buf,
|
||||
size_t n,
|
||||
const char* what) const {
|
||||
std::copy(
|
||||
data_.data() + pos, data_.data() + pos + n, reinterpret_cast<char*>(buf));
|
||||
return n;
|
||||
}
|
||||
#endif
|
||||
|
||||
IValue pickle_load(const std::vector<char>& data) {
|
||||
|
|
@ -103,6 +142,26 @@ IValue pickle_load(const std::vector<char>& data) {
|
|||
#endif
|
||||
};
|
||||
|
||||
// A specialized version of pickle_load that can load custom objects.
|
||||
c10::IValue pickle_load_obj(std::string_view data) {
|
||||
#ifndef C10_MOBILE
|
||||
caffe2::serialize::PyTorchStreamReader reader(
|
||||
std::make_unique<torch::jit::StringViewReader>(data));
|
||||
return torch::jit::readArchiveAndTensors(
|
||||
"data",
|
||||
/*pickle_prefix=*/"",
|
||||
/*tensor_prefix=*/"",
|
||||
/*type_resolver=*/customClassResolver,
|
||||
/*obj_loader=*/torch::jit::ObjLoaderFunc,
|
||||
/*device=*/c10::nullopt,
|
||||
reader);
|
||||
#else
|
||||
AT_ERROR(
|
||||
"pickle_load not supported on mobile "
|
||||
"(see https://github.com/pytorch/pytorch/pull/30108)");
|
||||
#endif
|
||||
}
|
||||
|
||||
IValue unpickle(
|
||||
std::function<size_t(char*, size_t)> reader,
|
||||
TypeResolver type_resolver,
|
||||
|
|
|
|||
|
|
@ -60,6 +60,10 @@ TORCH_API std::vector<char> pickle_save(const IValue& ivalue);
|
|||
/// `torch::pickle_save` in C++ or `torch.save` in Python
|
||||
TORCH_API IValue pickle_load(const std::vector<char>& data);
|
||||
|
||||
/// Deserialize a `torch::IValue` from bytes produced by either
|
||||
/// `torch::pickle_save` in C++ or `torch.save` in Python with custom object.
|
||||
TORCH_API IValue pickle_load_obj(std::string_view data);
|
||||
|
||||
/// `reader` is a function that takes in a size to read from some pickled
|
||||
/// binary. `reader` should remember where it last read, and return
|
||||
/// the number of bytes read.
|
||||
|
|
@ -117,5 +121,20 @@ class VectorReader : public caffe2::serialize::ReadAdapterInterface {
|
|||
private:
|
||||
std::vector<char> data_;
|
||||
};
|
||||
|
||||
class StringViewReader : public caffe2::serialize::ReadAdapterInterface {
|
||||
public:
|
||||
StringViewReader(std::string_view data) : data_(data) {}
|
||||
|
||||
size_t size() const override {
|
||||
return data_.size();
|
||||
}
|
||||
|
||||
size_t read(uint64_t pos, void* buf, size_t n, const char* what)
|
||||
const override;
|
||||
|
||||
private:
|
||||
std::string_view data_;
|
||||
};
|
||||
#endif
|
||||
} // namespace torch::jit
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user