Use streams in JIT serialization, allow JIT serialization to/from buffer (#11932)
Summary: This PR replaces the use of `std::FILE` with `istream`/`ostream` for JIT serialization. It uses this mechanism to add the ability to serialize to and from binary buffers, in addition to files, both in `libtorch` and from Python. `getExportImportCopy` in `test_jit.py` has been updated so that both the file and buffer codepaths are exercised during tests.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/11932
Differential Revision: D10084303
Pulled By: apaszke
fbshipit-source-id: b850801b3932922fa1dbac6fdaed5063d58bc20d
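For illustration, the Python-level round trip this enables looks roughly like the following. This is a minimal sketch; `MyModule` is a placeholder for any `ScriptModule`.

import io

import torch


class MyModule(torch.jit.ScriptModule):
    @torch.jit.script_method
    def forward(self, x):
        return x + 1


m = MyModule()

# Serialize into an in-memory buffer instead of a file on disk.
buffer = io.BytesIO()
torch.jit.save(m, buffer)

# Rewind and load the module back without touching the filesystem.
buffer.seek(0)
loaded = torch.jit.load(buffer)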
This commit is contained in:
parent d291cf7de6
commit 5be0baefa2
@@ -240,14 +240,10 @@ class JitTestCase(TestCase):
             imported = torch.jit.load(f.name)
         finally:
             os.unlink(f.name)
-        f = tempfile.NamedTemporaryFile(delete=False)
-        try:
-            f.close()
-            imported.save(f.name)
-            imported = torch.jit.load(f.name)
-        finally:
-            os.unlink(f.name)
-        return imported
+        buffer = io.BytesIO()
+        torch.jit.save(imported, buffer)
+        buffer.seek(0)
+        return torch.jit.load(buffer)
 
     def assertGraphContains(self, graph, kind):
         self.assertTrue(any(n.kind() == kind for n in graph.nodes()))
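Put together, the updated helper exercises both codepaths. A self-contained sketch follows (the function name is assumed; the real helper is a method on `JitTestCase`):

import io
import os
import tempfile

import torch


def export_import_copy(m):
    # File codepath: close the NamedTemporaryFile first so Windows can reopen it.
    f = tempfile.NamedTemporaryFile(delete=False)
    try:
        f.close()
        m.save(f.name)
        imported = torch.jit.load(f.name)
    finally:
        os.unlink(f.name)

    # Buffer codepath: round-trip the re-loaded module through an in-memory buffer.
    buffer = io.BytesIO()
    torch.jit.save(imported, buffer)
    buffer.seek(0)
    return torch.jit.load(buffer)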
@@ -15,6 +15,7 @@
 #include <vector>
 #include <string>
 #include <sstream>
+#include <fstream>
 
 namespace torch { namespace jit {
 

@@ -425,7 +426,7 @@ void GraphEncoder::EncodeTensor(
 class ModuleEncoder: public EncoderBase {
  public:
   ModuleEncoder(const script::Module &module,
-                const std::string &filename);
+                std::ostream& out);
 
  private:
   void EncodeModule(onnx::GraphProto *graph_proto, const script::Module &module);

@@ -448,7 +449,7 @@ class ModuleEncoder: public EncoderBase {
 
   virtual void EncodeTensor(onnx::TensorProto *tensor_proto,
                             const at::Tensor &tensor,
-                            const at::optional<std::string> external_ref) override;
+                            const at::optional<std::string> external_ref = {}) override;
 
   virtual void EncodeIntermediateValueInfo(onnx::GraphProto *graph_proto,
                                            const Value* n) override;

@@ -462,7 +463,7 @@ class ModuleEncoder: public EncoderBase {
                       const TypePtr& type,
                       const std::string& name);
 
-  PyTorchFileWriter file_writer_;
+  PyTorchStreamWriter stream_writer_;
   // Used to deduplicate tensor storages
   std::unordered_map<const void*, uint64_t> storage_dedup_map_;
 

@@ -475,9 +476,9 @@ class ModuleEncoder: public EncoderBase {
 
 ModuleEncoder::ModuleEncoder(
     const script::Module &module,
-    const std::string &filename)
+    std::ostream& out)
     : EncoderBase(onnx_torch::OperatorExportTypes::RAW, false),
-      file_writer_(filename) {
+      stream_writer_(out) {
   model_proto_.set_doc_string("THIS PROTO IS NOT STANDARD ONNX");
   EncodeModule(model_proto_.mutable_graph(), module);
 }

@@ -586,7 +587,7 @@ void ModuleEncoder::EncodeModule(
   EncodeParameters(graph_proto, module, "");
   EncodeMethods(graph_proto, module, "");
   auto str = model_proto_.SerializeAsString();
-  file_writer_.writeRecord(str.data(), str.size());
+  stream_writer_.writeRecord(str.data(), str.size());
 }
 
 void ModuleEncoder::EncodeParameters(

@@ -674,7 +675,7 @@ void ModuleEncoder::EncodeMethod(
 void ModuleEncoder::EncodeTensor(
     onnx::TensorProto *tensor_proto,
     const at::Tensor &tensor,
-    const at::optional<std::string> external_ref = {}) {
+    const at::optional<std::string> external_ref) {
   auto storage_ptr = tensor.storage().unsafeGetStorageImpl();
   auto dedup_it = storage_dedup_map_.find(storage_ptr);
   if (dedup_it != storage_dedup_map_.end()) {

@@ -693,7 +694,7 @@ void ModuleEncoder::EncodeTensor(
         .cpu();
   }
 
-  auto record_number = file_writer_.writeRecord(
+  auto record_number = stream_writer_.writeRecord(
     static_cast<char*>(t.storage().data()), t.type().elementSizeInBytes() * t.storage().size());
   tensor_proto->add_int64_data(record_number);
   storage_dedup_map_[storage_ptr] = record_number;

@@ -919,8 +920,14 @@ std::tuple<std::string, RawDataExportMap> ExportGraph(
                          graph_encoder.get_raw_data_export_map());
 }
 
+void ExportModule(const script::Module& module, std::ostream& out) {
+  ModuleEncoder(module, out);
+}
+
 void ExportModule(const script::Module& module, const std::string &filename) {
-  ModuleEncoder(module, filename);
+  std::ofstream out(filename, std::ios_base::binary);
+
+  ExportModule(module, out);
 }
 
 }}
@@ -4,6 +4,8 @@
 #include "torch/csrc/jit/script/module.h"
 #include "torch/csrc/onnx/onnx.h"
 
+#include <ostream>
+
 namespace torch { namespace jit {
 
 // This map is used to keep track of parameters that should be exported

@@ -34,6 +36,10 @@ TORCH_API std::string PrettyPrintExportedGraph(
         = ::torch::onnx::OperatorExportTypes::ONNX,
     bool google_printer = false);
 
+TORCH_API void ExportModule(
+    const script::Module& module,
+    std::ostream& out);
+
 TORCH_API void ExportModule(
     const script::Module& module,
     const std::string& filename);
@@ -11,6 +11,7 @@
 #include <unordered_map>
 #include <vector>
 #include <string>
+#include <fstream>
 
 namespace torch { namespace jit {
 

@@ -181,7 +182,7 @@ void DecoderBase::buildBlock(const onnx::GraphProto& graph_proto, Block* block,
 class ModuleDecoder : DecoderBase {
  public:
   ModuleDecoder(ModuleLookup module_lookup,
-                const std::string& filename);
+                std::istream& in);
 
  private:
   virtual std::shared_ptr<Graph> buildGraph(const onnx::GraphProto& graph_proto) override;

@@ -205,7 +206,7 @@ class ModuleDecoder : DecoderBase {
       ModuleLookup module_lookup,
       const std::string fullname);
 
-  PyTorchFileReader file_reader_;
+  PyTorchStreamReader stream_reader_;
   std::unordered_map<uint64_t, std::shared_ptr<at::Storage>> storage_map_;
   std::unordered_map<std::string, const onnx::TypeProto*> value_type_map_;
 };

@@ -319,7 +320,7 @@ at::Tensor ModuleDecoder::buildTensorCommon(
   if (storage_it == storage_map_.end()) {
     at::DataPtr storage_ptr;
     int64_t size;
-    std::tie(storage_ptr, size) = file_reader_.getRecordWithKey(record_number);
+    std::tie(storage_ptr, size) = stream_reader_.getRecordWithKey(record_number);
     auto storage = std::make_shared<at::Storage>(
       at::CPU(type).typeMeta(),
       std::move(storage_ptr),

@@ -353,10 +354,10 @@ std::pair<std::shared_ptr<script::Module>, std::string> ModuleDecoder::parseFull
 
 ModuleDecoder::ModuleDecoder(
     ModuleLookup module_lookup,
-    const std::string &filename) :
-    file_reader_(filename) {
+    std::istream& in) :
+    stream_reader_(in) {
   auto model_proto = onnx::ModelProto();
-  auto record = file_reader_.getLastRecord();
+  auto record = stream_reader_.getLastRecord();
   model_proto.ParsePartialFromArray(std::get<0>(record).get(), std::get<1>(record));
   auto graph_proto = model_proto.graph();
 

@@ -397,11 +398,19 @@ ModuleDecoder::ModuleDecoder(
 
 void import_ir_module(
     ModuleLookup module_lookup,
-    const std::string& filename) {
-  ModuleDecoder(module_lookup, filename);
+    std::istream& in) {
+  ModuleDecoder(module_lookup, in);
 }
 
-std::shared_ptr<script::Module> load(const std::string& filename) {
+void import_ir_module(
+    ModuleLookup module_lookup,
+    const std::string& filename) {
+  std::ifstream in(filename, std::ios_base::binary);
+
+  ModuleDecoder(module_lookup, in);
+}
+
+std::shared_ptr<script::Module> load(std::istream& in) {
   auto module = std::make_shared<script::Module>();
 
   auto module_lookup = [&](const std::vector<std::string>& qualified_name) {

@@ -414,7 +423,17 @@ std::shared_ptr<script::Module> load(const std::string& filename) {
     }
     return curr;
   };
-  ModuleDecoder(module_lookup, filename);
+
+  ModuleDecoder(module_lookup, in);
+
+  return module;
+}
+
+std::shared_ptr<script::Module> load(const std::string& filename) {
+  std::ifstream in(filename, std::ios_base::binary);
+
+  auto module = load(in);
+
   return module;
 }
 
@@ -3,6 +3,8 @@
 #include "torch/csrc/jit/ir.h"
 #include "torch/csrc/jit/script/module.h"
 
+#include <istream>
+
 namespace torch {
 namespace jit {
 

@@ -13,11 +15,18 @@ TORCH_API void import_ir_module(
     ModuleLookup module_lookup,
     const std::string& filename);
 
+TORCH_API void import_ir_module(
+    ModuleLookup module_lookup,
+    std::istream& in);
+
 /// Loads a serialized `script::Module` from the given `filename`.
 ///
 /// The file stored at the location given in `filename` must contain a
 /// serialized `script::Module`, exported either via `ScriptModule.save()` in
 /// Python or `torch::jit::ExportModule` in C++.
+TORCH_API std::shared_ptr<script::Module> load(std::istream& in);
+
 TORCH_API std::shared_ptr<script::Module> load(const std::string& filename);
 
 } // namespace jit
@@ -227,7 +227,6 @@ void initJITBindings(PyObject *module) {
     return createPyObjectForStack(std::move(stack));
   });
 
-
   py::class_<PyTorchFileWriter>(m, "PyTorchFileWriter")
       .def(py::init<std::string>())
      .def("write_record", &PyTorchFileWriter::writeRecord)
@@ -371,7 +371,14 @@ void initJitScriptBindings(PyObject* module) {
   // public.
   py::class_<Module, std::shared_ptr<Module>>(m, "ScriptModule")
       .def(py::init<>())
-      .def("save", &Module::save)
+      .def("save", [](std::shared_ptr<Module> m, const std::string& filename) {
+          m->save(filename);
+      })
+      .def("save_to_buffer", [](std::shared_ptr<Module> m) {
+          std::ostringstream buf;
+          m->save(buf);
+          return py::bytes(buf.str());
+      })
       .def("_set_optimized", &Module::set_optimized)
       .def(
           "_define",

@@ -534,7 +541,13 @@ void initJitScriptBindings(PyObject* module) {
   });
 
   m.def("merge_type_from_type_comment", &mergeTypesFromTypeComment);
-  m.def("import_ir_module", import_ir_module);
+  m.def("import_ir_module", [](ModuleLookup module_lookup, const std::string& filename) {
+    import_ir_module(module_lookup, filename);
+  });
+  m.def("import_ir_module_from_buffer", [](ModuleLookup module_lookup, const std::string& buffer) {
+    std::istringstream in(buffer);
+    import_ir_module(module_lookup, in);
+  });
 }
 
 } // namespace script
@@ -71,6 +71,10 @@ void Method::ensure_defined() {
   }
 }
 
+void Module::save(std::ostream& out) {
+  ExportModule(*this, out);
+}
+
 void Module::save(const std::string& filename) {
   ExportModule(*this, filename);
 }
@@ -20,6 +20,7 @@
 #include <string>
 #include <unordered_map>
 #include <vector>
+#include <ostream>
 
 // This file contains classes which assist in desugaring Python style
 // modules and their methods into flattened graphs which don't have any

@@ -376,6 +377,8 @@ struct Module {
     return get_method(method_name)({IValue(std::forward<Types>(args))...});
   }
 
+  void save(std::ostream& out);
+
   void save(const std::string& filename);
 
  private:
@@ -3,6 +3,9 @@
 #include <cstdio>
 #include <cstring>
 #include <cerrno>
+#include <istream>
+#include <ostream>
+#include <fstream>
 
 namespace torch { namespace jit {
 

@@ -75,25 +78,16 @@ namespace {
 static constexpr uint64_t kFileFormatVersion = 0x1L;
 static constexpr uint8_t kPadValue = 0xEF;
 
-void wrapPErrorAndThrow(const std::string& msg) {
-  std::ostringstream oss;
-  oss << msg << " : " << strerror(errno);
-  throw std::runtime_error(oss.str());
-}
 } // namespace
 
-class PyTorchFileReader {
+class PyTorchStreamReader {
  public:
-  PyTorchFileReader(std::string filename) {
-    fp = std::fopen(filename.c_str(), "rb");
-    if (!fp) {
-      wrapPErrorAndThrow("Couldn't open file for reading!");
-    }
+  PyTorchStreamReader(std::istream& in_) : in(in_) {
     // Store file size so we know when we're done reading because the f* APIs
     // don't do a good job of that
-    std::fseek(fp, 0L, SEEK_END);
-    file_size = std::ftell(fp);
-    std::fseek(fp, 0L, SEEK_SET);
+    in.seekg(0L, in.end);
+    file_size = in.tellg();
+    in.seekg(0L);
     readAndValidateFileHeader();
     // Do this now since we're reasonably sure this is actually a PyT file from
     // the header.
@@ -115,7 +109,7 @@ class PyTorchFileReader {
     }
     // Seek to the provided offset
     cursor = key;
-    std::fseek(fp, cursor, SEEK_SET);
+    in.seekg(cursor);
     auto tag = read64BitIntegerLittleEndian();
     if (tag != RecordTags::STORAGE) {
       throw std::runtime_error("Attempted to read a record of non-storage type");

@@ -124,18 +118,16 @@ class PyTorchFileReader {
     seekToNextAlignmentBoundary();
     auto ptr = malloc(size);
     at::DataPtr retval(ptr, ptr, free, at::kCPU);
-    if (!std::fread(ptr, size, 1, fp)) {
-      wrapPErrorAndThrow("Failed to read data from record");
-    }
+    in.read((char*)ptr, size);
     cursor += size;
     seekToNextAlignmentBoundary();
     return std::tuple<at::DataPtr, size_t>(std::move(retval), size);
   }
-  ~PyTorchFileReader() {
-    std::fclose(fp);
+  ~PyTorchStreamReader() {
   }
  private:
-  FILE *fp;
+  std::istream& in;
   size_t cursor = 0;
   size_t file_size;
   size_t last_record_offset;

@@ -144,8 +136,9 @@ class PyTorchFileReader {
   uint64_t read64BitIntegerLittleEndian() {
     uint64_t retval;
     // TODO endian swap on platforms that need it?
-    size_t read_bytes = std::fread(&retval, 1u, 8u, fp);
-    if (read_bytes != 8u) {
+    in.read(reinterpret_cast<char *>(&retval), 8);
+    std::streamsize read_bytes = in.gcount();
+    if (read_bytes != 8) {
       std::ostringstream errmsg;
       errmsg << "Expected to read 8 bytes but got " << read_bytes;
       throw std::runtime_error(errmsg.str());

@@ -158,7 +151,7 @@ class PyTorchFileReader {
     size_t next_offset = (cursor + kFieldAlignment) - (cursor % kFieldAlignment);
     size_t pad_amount = next_offset - cursor;
     cursor += pad_amount;
-    std::fseek(fp, cursor, SEEK_SET);
+    in.seekg(cursor);
   }
 
   // File format deserialization functions

@@ -183,7 +176,7 @@ class PyTorchFileReader {
     // Seek to location of file footer. We've already validated that the file
     // length is a multiple of the alignment size
     cursor = file_size - kFieldAlignment;
-    std::fseek(fp, cursor, SEEK_SET);
+    in.seekg(cursor);
     auto tag = read64BitIntegerLittleEndian();
     if (tag != RecordTags::FOOTER) {
       throw std::runtime_error("File footer has wrong record type. Is this"
@@ -197,13 +190,9 @@ class PyTorchFileReader {
   }
 };
 
-class PyTorchFileWriter {
+class PyTorchStreamWriter {
  public:
-  PyTorchFileWriter(const std::string& filename) {
-    fp = std::fopen(filename.c_str(), "wb");
-    if (!fp) {
-      wrapPErrorAndThrow("Unable to open PyTorch file for writing!");
-    }
+  PyTorchStreamWriter(std::ostream& out_) : out(out_) {
     writeFileHeader();
     // In the case that we do not write any records into this file, the last
     // record index written into the footer will point to the footer itself.

@@ -224,15 +213,14 @@ class PyTorchFileWriter {
     JIT_ASSERT(!finalized);
     writeFileFooter();
     finalized = true;
-    std::fclose(fp);
   }
-  ~PyTorchFileWriter() {
+  ~PyTorchStreamWriter() {
     if (!finalized) {
       writeEndOfFile();
     }
   }
  private:
-  FILE *fp;
+  std::ostream& out;
   size_t cursor = 0;
   bool finalized = false;
   size_t last_record_idx = 0;

@@ -240,17 +228,13 @@ class PyTorchFileWriter {
   // Utility functions
   void write64BitIntegerLittleEndian(const uint64_t value) {
     // TODO endian swap on platforms that need it?
-    if (!std::fwrite(&value, 8u, 1u, fp)) {
-      wrapPErrorAndThrow("Unable to write to file!");
-    }
+    out.write(reinterpret_cast<const char *>(&value), 8);
     cursor += 8u;
   }
 
   void writePad(const size_t num_bytes) {
     static std::vector<char> pad_buffer(kPadValue, kFieldAlignment);
-    if (!std::fwrite(pad_buffer.data(), num_bytes, 1u, fp)) {
-      wrapPErrorAndThrow("Unable to write to file!");
-    }
+    out.write(pad_buffer.data(), num_bytes);
     cursor += num_bytes;
   }
 

@@ -261,9 +245,7 @@ class PyTorchFileWriter {
   }
 
   void writeBuffer(const char* data, size_t size) {
-    if (!std::fwrite(data, size, 1u, fp)) {
-      wrapPErrorAndThrow("Unable to write to file!");
-    }
+    out.write(data, size);
     cursor += size;
   }
 
@@ -281,5 +263,43 @@ class PyTorchFileWriter {
   }
 };
 
+class PyTorchFileReader {
+ public:
+  PyTorchFileReader(const std::string& filename) :
+    in(filename, std::ios_base::binary),
+    stream_reader(in) {}
+
+  std::tuple<at::DataPtr, size_t> getLastRecord() {
+    return stream_reader.getLastRecord();
+  }
+
+  std::tuple<at::DataPtr, size_t> getRecordWithKey(uint64_t key) {
+    return stream_reader.getRecordWithKey(key);
+  }
+
+ private:
+  std::ifstream in;
+  PyTorchStreamReader stream_reader;
+};
+
+class PyTorchFileWriter {
+ public:
+  PyTorchFileWriter(const std::string& filename) :
+    out(filename, std::ios_base::binary),
+    stream_writer(out) {}
+
+  uint64_t writeRecord(const char* data, size_t size) {
+    return stream_writer.writeRecord(data, size);
+  }
+
+  void writeEndOfFile() {
+    stream_writer.writeEndOfFile();
+    out.close();
+  }
+
+ private:
+  std::ofstream out;
+  PyTorchStreamWriter stream_writer;
+};
+
 }} // namespace torch::jit
@@ -20,6 +20,8 @@ import copy
 import numbers
 import collections
 import re
+if sys.version_info[0] > 2:
+    import pathlib
 
 
 def _parse_env(name, default, true_message, false_message):

@@ -58,19 +60,27 @@ def scope(scope_name):
         tracing_state.pop_scope()
 
 
-def load(filename):
+def load(f):
     r"""
-        Load a ``ScriptModule`` previously saved with :func:`save <torch.jit.ScriptModule.save>`
+        Load a ``ScriptModule`` previously saved with :func:`save <torch.jit.save>`
 
         .. DANGER::
            All previously saved modules, no matter their device, are always loaded onto the CPU.
           This is different from :func:`torch.load`'s semantics and may change in the future.
 
        Arguments:
-            filename (string): the file to load
+            f: a file-like object (has to implement read, readline, tell, and seek),
+                or a string containing a file name
 
        Returns:
            A ``ScriptModule`` object.
+
+        Example:
+            >>> torch.jit.load('scriptmodule.pt')
+            # Load ScriptModule from io.BytesIO object
+            >>> with open('scriptmodule.pt', 'rb') as f:
+                    buffer = io.BytesIO(f.read())
+            >>> torch.jit.load(buffer)
     """
     m = ScriptModule()

@@ -82,10 +92,48 @@ def load(filename):
             curr = getattr(curr, name)
         return curr
 
-    torch._C.import_ir_module(module_lookup, filename)
+    if isinstance(f, str) or \
+            (sys.version_info[0] == 2 and isinstance(f, unicode)) or \
+            (sys.version_info[0] == 3 and isinstance(f, pathlib.Path)):
+        torch._C.import_ir_module(module_lookup, f)
+    else:
+        torch._C.import_ir_module_from_buffer(module_lookup, f.read())
     return m
 
 
+def save(m, f):
+    """
+        Saves a ScriptModule to a file.
+
+        Args:
+            m: a ScriptModule to save
+            f: a file-like object (has to implement write and flush) or a string
+                containing a file name
+
+        .. warning::
+            If you are using Python 2, torch.save does NOT support StringIO.StringIO
+            as a valid file-like object. This is because the write method should return
+            the number of bytes written; StringIO.write() does not do this.
+
+            Please use something like io.BytesIO instead.
+
+        Example:
+            >>> m = torch.jit.ScriptModule()
+            >>> # Save to file
+            >>> torch.jit.save(m, 'scriptmodule.pt')
+            >>> # Save to io.BytesIO buffer
+            >>> buffer = io.BytesIO()
+            >>> torch.jit.save(m, buffer)
+    """
+    if isinstance(f, str) or \
+            (sys.version_info[0] == 2 and isinstance(f, unicode)) or \
+            (sys.version_info[0] == 3 and isinstance(f, pathlib.Path)):
+        m.save(f)
+    else:
+        ret = m.save_to_buffer()
+        f.write(ret)
+
+
 def get_trace_graph(f, args=(), kwargs=None):
     """
     Trace a function or model, returning a tuple consisting of the both the