Use streams in JIT serialization, allow JIT serialization to/from buffer (#11932)

Summary:
This PR replaces the use of `std::FILE` with `std::istream`/`std::ostream` throughout JIT serialization.
Building on that, modules can now be serialized to and deserialized from in-memory binary buffers, in addition to files, both in `libtorch` and from Python.

`getExportImportCopy` in `test_jit.py` has been updated so that both file and buffer codepaths are exercised during tests.
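
In `libtorch`, the round trip through a buffer looks roughly like this (a sketch; the header paths and the `roundTrip` helper are assumptions for illustration, not part of the diff):

```cpp
#include <memory>
#include <sstream>

#include "torch/csrc/jit/export.h"  // ExportModule (path assumed)
#include "torch/csrc/jit/import.h"  // torch::jit::load (path assumed)

// Serialize a module into an in-memory stream and immediately reload it.
std::shared_ptr<torch::jit::script::Module> roundTrip(
    const torch::jit::script::Module& module) {
  std::stringstream buffer;
  torch::jit::ExportModule(module, buffer);  // new std::ostream overload
  return torch::jit::load(buffer);           // new std::istream overload
}
```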
Pull Request resolved: https://github.com/pytorch/pytorch/pull/11932

Differential Revision: D10084303

Pulled By: apaszke

fbshipit-source-id: b850801b3932922fa1dbac6fdaed5063d58bc20d
Authored by Luca Antiga on 2018-09-28 07:41:26 -07:00, committed by Facebook Github Bot
parent d291cf7de6
commit 5be0baefa2
11 changed files with 201 additions and 77 deletions

View File

@ -240,14 +240,10 @@ class JitTestCase(TestCase):
imported = torch.jit.load(f.name)
finally:
os.unlink(f.name)
f = tempfile.NamedTemporaryFile(delete=False)
try:
f.close()
imported.save(f.name)
imported = torch.jit.load(f.name)
finally:
os.unlink(f.name)
return imported
buffer = io.BytesIO()
torch.jit.save(imported, buffer)
buffer.seek(0)
return torch.jit.load(buffer)
def assertGraphContains(self, graph, kind):
self.assertTrue(any(n.kind() == kind for n in graph.nodes()))

View File

@ -15,6 +15,7 @@
#include <vector>
#include <string>
#include <sstream>
#include <fstream>
namespace torch { namespace jit {
@ -425,7 +426,7 @@ void GraphEncoder::EncodeTensor(
class ModuleEncoder: public EncoderBase {
public:
ModuleEncoder(const script::Module &module,
const std::string &filename);
std::ostream& out);
private:
void EncodeModule(onnx::GraphProto *graph_proto, const script::Module &module);
@ -448,7 +449,7 @@ class ModuleEncoder: public EncoderBase {
virtual void EncodeTensor(onnx::TensorProto *tensor_proto,
const at::Tensor &tensor,
const at::optional<std::string> external_ref) override;
const at::optional<std::string> external_ref = {}) override;
virtual void EncodeIntermediateValueInfo(onnx::GraphProto *graph_proto,
const Value* n) override;
@ -462,7 +463,7 @@ class ModuleEncoder: public EncoderBase {
const TypePtr& type,
const std::string& name);
PyTorchFileWriter file_writer_;
PyTorchStreamWriter stream_writer_;
// Used to deduplicate tensor storages
std::unordered_map<const void*, uint64_t> storage_dedup_map_;
@ -475,9 +476,9 @@ class ModuleEncoder: public EncoderBase {
ModuleEncoder::ModuleEncoder(
const script::Module &module,
const std::string &filename)
std::ostream& out)
: EncoderBase(onnx_torch::OperatorExportTypes::RAW, false),
file_writer_(filename) {
stream_writer_(out) {
model_proto_.set_doc_string("THIS PROTO IS NOT STANDARD ONNX");
EncodeModule(model_proto_.mutable_graph(), module);
}
@ -586,7 +587,7 @@ void ModuleEncoder::EncodeModule(
EncodeParameters(graph_proto, module, "");
EncodeMethods(graph_proto, module, "");
auto str = model_proto_.SerializeAsString();
file_writer_.writeRecord(str.data(), str.size());
stream_writer_.writeRecord(str.data(), str.size());
}
void ModuleEncoder::EncodeParameters(
@ -674,7 +675,7 @@ void ModuleEncoder::EncodeMethod(
void ModuleEncoder::EncodeTensor(
onnx::TensorProto *tensor_proto,
const at::Tensor &tensor,
const at::optional<std::string> external_ref = {}) {
const at::optional<std::string> external_ref) {
auto storage_ptr = tensor.storage().unsafeGetStorageImpl();
auto dedup_it = storage_dedup_map_.find(storage_ptr);
if (dedup_it != storage_dedup_map_.end()) {
@ -693,7 +694,7 @@ void ModuleEncoder::EncodeTensor(
.cpu();
}
auto record_number = file_writer_.writeRecord(
auto record_number = stream_writer_.writeRecord(
static_cast<char*>(t.storage().data()), t.type().elementSizeInBytes() * t.storage().size());
tensor_proto->add_int64_data(record_number);
storage_dedup_map_[storage_ptr] = record_number;
@ -919,8 +920,14 @@ std::tuple<std::string, RawDataExportMap> ExportGraph(
graph_encoder.get_raw_data_export_map());
}
void ExportModule(const script::Module& module, std::ostream& out) {
ModuleEncoder(module, out);
}
void ExportModule(const script::Module& module, const std::string &filename) {
ModuleEncoder(module, filename);
std::ofstream out(filename, std::ios_base::binary);
ExportModule(module, out);
}
}}
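
With the `std::ostream` overload in place, serializing straight into an in-memory string becomes trivial; a sketch (the helper name is illustrative), mirroring what the `save_to_buffer` Python binding below does:

```cpp
#include <sstream>
#include <string>

#include "torch/csrc/jit/export.h"  // ExportModule (path assumed)

// Serialize a module into a std::string, e.g. for a cache or transport layer.
std::string serializeToString(const torch::jit::script::Module& module) {
  std::ostringstream out;
  torch::jit::ExportModule(module, out);  // same codepath the file overload wraps
  return out.str();
}
```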

View File

@ -4,6 +4,8 @@
#include "torch/csrc/jit/script/module.h"
#include "torch/csrc/onnx/onnx.h"
#include <ostream>
namespace torch { namespace jit {
// This map is used to keep track of parameters that should be exported
@ -34,6 +36,10 @@ TORCH_API std::string PrettyPrintExportedGraph(
= ::torch::onnx::OperatorExportTypes::ONNX,
bool google_printer = false);
TORCH_API void ExportModule(
const script::Module& module,
std::ostream& out);
TORCH_API void ExportModule(
const script::Module& module,
const std::string& filename);

View File

@ -11,6 +11,7 @@
#include <unordered_map>
#include <vector>
#include <string>
#include <fstream>
namespace torch { namespace jit {
@ -181,7 +182,7 @@ void DecoderBase::buildBlock(const onnx::GraphProto& graph_proto, Block* block,
class ModuleDecoder : DecoderBase {
public:
ModuleDecoder(ModuleLookup module_lookup,
const std::string& filename);
std::istream& in);
private:
virtual std::shared_ptr<Graph> buildGraph(const onnx::GraphProto& graph_proto) override;
@ -205,7 +206,7 @@ class ModuleDecoder : DecoderBase {
ModuleLookup module_lookup,
const std::string fullname);
PyTorchFileReader file_reader_;
PyTorchStreamReader stream_reader_;
std::unordered_map<uint64_t, std::shared_ptr<at::Storage>> storage_map_;
std::unordered_map<std::string, const onnx::TypeProto*> value_type_map_;
};
@ -319,7 +320,7 @@ at::Tensor ModuleDecoder::buildTensorCommon(
if (storage_it == storage_map_.end()) {
at::DataPtr storage_ptr;
int64_t size;
std::tie(storage_ptr, size) = file_reader_.getRecordWithKey(record_number);
std::tie(storage_ptr, size) = stream_reader_.getRecordWithKey(record_number);
auto storage = std::make_shared<at::Storage>(
at::CPU(type).typeMeta(),
std::move(storage_ptr),
@ -353,10 +354,10 @@ std::pair<std::shared_ptr<script::Module>, std::string> ModuleDecoder::parseFull
ModuleDecoder::ModuleDecoder(
ModuleLookup module_lookup,
const std::string &filename) :
file_reader_(filename) {
std::istream& in) :
stream_reader_(in) {
auto model_proto = onnx::ModelProto();
auto record = file_reader_.getLastRecord();
auto record = stream_reader_.getLastRecord();
model_proto.ParsePartialFromArray(std::get<0>(record).get(), std::get<1>(record));
auto graph_proto = model_proto.graph();
@ -397,11 +398,19 @@ ModuleDecoder::ModuleDecoder(
void import_ir_module(
ModuleLookup module_lookup,
const std::string& filename) {
ModuleDecoder(module_lookup, filename);
std::istream& in) {
ModuleDecoder(module_lookup, in);
}
std::shared_ptr<script::Module> load(const std::string& filename) {
void import_ir_module(
ModuleLookup module_lookup,
const std::string& filename) {
std::ifstream in(filename, std::ios_base::binary);
ModuleDecoder(module_lookup, in);
}
std::shared_ptr<script::Module> load(std::istream& in) {
auto module = std::make_shared<script::Module>();
auto module_lookup = [&](const std::vector<std::string>& qualified_name) {
@ -414,7 +423,17 @@ std::shared_ptr<script::Module> load(const std::string& filename) {
}
return curr;
};
ModuleDecoder(module_lookup, filename);
ModuleDecoder(module_lookup, in);
return module;
}
std::shared_ptr<script::Module> load(const std::string& filename) {
std::ifstream in(filename, std::ios_base::binary);
auto module = load(in);
return module;
}
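
The loading side is symmetric: bytes already in memory can be wrapped in an `std::istringstream` and handed to the new `load` overload, which is essentially what the `import_ir_module_from_buffer` binding below does (a sketch; the helper name is illustrative):

```cpp
#include <memory>
#include <sstream>
#include <string>

#include "torch/csrc/jit/import.h"  // torch::jit::load (path assumed)

// Deserialize a module from bytes held in memory.
std::shared_ptr<torch::jit::script::Module> loadFromBytes(const std::string& bytes) {
  std::istringstream in(bytes, std::ios_base::binary);
  return torch::jit::load(in);
}
```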

View File

@ -3,6 +3,8 @@
#include "torch/csrc/jit/ir.h"
#include "torch/csrc/jit/script/module.h"
#include <istream>
namespace torch {
namespace jit {
@ -13,11 +15,18 @@ TORCH_API void import_ir_module(
ModuleLookup module_lookup,
const std::string& filename);
TORCH_API void import_ir_module(
ModuleLookup module_lookup,
std::istream& in);
/// Loads a serialized `script::Module` from the given `filename`.
///
/// The file stored at the location given in `filename` must contain a
/// serialized `script::Module`, exported either via `ScriptModule.save()` in
/// Python or `torch::jit::ExportModule` in C++.
TORCH_API std::shared_ptr<script::Module> load(std::istream& in);
TORCH_API std::shared_ptr<script::Module> load(const std::string& filename);
} // namespace jit

View File

@ -227,7 +227,6 @@ void initJITBindings(PyObject *module) {
return createPyObjectForStack(std::move(stack));
});
py::class_<PyTorchFileWriter>(m, "PyTorchFileWriter")
.def(py::init<std::string>())
.def("write_record", &PyTorchFileWriter::writeRecord)

View File

@ -371,7 +371,14 @@ void initJitScriptBindings(PyObject* module) {
// public.
py::class_<Module, std::shared_ptr<Module>>(m, "ScriptModule")
.def(py::init<>())
.def("save", &Module::save)
.def("save", [](std::shared_ptr<Module> m, const std::string& filename) {
m->save(filename);
})
.def("save_to_buffer", [](std::shared_ptr<Module> m) {
std::ostringstream buf;
m->save(buf);
return py::bytes(buf.str());
})
.def("_set_optimized", &Module::set_optimized)
.def(
"_define",
@ -534,7 +541,13 @@ void initJitScriptBindings(PyObject* module) {
});
m.def("merge_type_from_type_comment", &mergeTypesFromTypeComment);
m.def("import_ir_module", import_ir_module);
m.def("import_ir_module", [](ModuleLookup module_lookup, const std::string& filename) {
import_ir_module(module_lookup, filename);
});
m.def("import_ir_module_from_buffer", [](ModuleLookup module_lookup, const std::string& buffer) {
std::istringstream in(buffer);
import_ir_module(module_lookup, in);
});
}
} // namespace script

View File

@ -71,6 +71,10 @@ void Method::ensure_defined() {
}
}
void Module::save(std::ostream& out) {
ExportModule(*this, out);
}
void Module::save(const std::string& filename) {
ExportModule(*this, filename);
}
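
At the `Module` level both overloads are now available side by side; a sketch (the file name and helper are illustrative):

```cpp
#include <sstream>

#include "torch/csrc/jit/script/module.h"

// Save the same module to an in-memory buffer and to a file.
void saveBoth(torch::jit::script::Module& module) {
  std::ostringstream buffer;
  module.save(buffer);              // new: any std::ostream
  module.save("scriptmodule.pt");   // unchanged: convenience file overload
}
```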

View File

@ -20,6 +20,7 @@
#include <string>
#include <unordered_map>
#include <vector>
#include <ostream>
// This file contains classes which assist in desugaring Python style
// modules and their methods into flattened graphs which don't have any
@ -376,6 +377,8 @@ struct Module {
return get_method(method_name)({IValue(std::forward<Types>(args))...});
}
void save(std::ostream& out);
void save(const std::string& filename);
private:

View File

@ -3,6 +3,9 @@
#include <cstdio>
#include <cstring>
#include <cerrno>
#include <istream>
#include <ostream>
#include <fstream>
namespace torch { namespace jit {
@ -75,25 +78,16 @@ namespace {
static constexpr uint64_t kFileFormatVersion = 0x1L;
static constexpr uint8_t kPadValue = 0xEF;
void wrapPErrorAndThrow(const std::string& msg) {
std::ostringstream oss;
oss << msg << " : " << strerror(errno);
throw std::runtime_error(oss.str());
}
} // namespace
class PyTorchFileReader {
class PyTorchStreamReader {
public:
PyTorchFileReader(std::string filename) {
fp = std::fopen(filename.c_str(), "rb");
if (!fp) {
wrapPErrorAndThrow("Couldn't open file for reading!");
}
PyTorchStreamReader(std::istream& in_) : in(in_) {
// Store file size so we know when we're done reading because the f* APIs
// don't do a good job of that
std::fseek(fp, 0L, SEEK_END);
file_size = std::ftell(fp);
std::fseek(fp, 0L, SEEK_SET);
in.seekg(0L, in.end);
file_size = in.tellg();
in.seekg(0L);
readAndValidateFileHeader();
// Do this now since we're reasonably sure this is actually a PyT file from
// the header.
@ -115,7 +109,7 @@ class PyTorchFileReader {
}
// Seek to the provided offset
cursor = key;
std::fseek(fp, cursor, SEEK_SET);
in.seekg(cursor);
auto tag = read64BitIntegerLittleEndian();
if (tag != RecordTags::STORAGE) {
throw std::runtime_error("Attempted to read a record of non-storage type");
@ -124,18 +118,16 @@ class PyTorchFileReader {
seekToNextAlignmentBoundary();
auto ptr = malloc(size);
at::DataPtr retval(ptr, ptr, free, at::kCPU);
if (!std::fread(ptr, size, 1, fp)) {
wrapPErrorAndThrow("Failed to read data from record");
}
in.read((char*)ptr, size);
cursor += size;
seekToNextAlignmentBoundary();
return std::tuple<at::DataPtr, size_t>(std::move(retval), size);
}
~PyTorchFileReader() {
std::fclose(fp);
~PyTorchStreamReader() {
}
private:
FILE *fp;
std::istream& in;
size_t cursor = 0;
size_t file_size;
size_t last_record_offset;
@ -144,8 +136,9 @@ class PyTorchFileReader {
uint64_t read64BitIntegerLittleEndian() {
uint64_t retval;
// TODO endian swap on platforms that need it?
size_t read_bytes = std::fread(&retval, 1u, 8u, fp);
if (read_bytes != 8u) {
in.read(reinterpret_cast<char *>(&retval), 8);
std::streamsize read_bytes = in.gcount();
if (read_bytes != 8) {
std::ostringstream errmsg;
errmsg << "Expected to read 8 bytes but got " << read_bytes;
throw std::runtime_error(errmsg.str());
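
If the endian-swap TODO above were addressed by hand, a portable little-endian decode might look like this (illustrative only; not part of this diff):

```cpp
#include <cstdint>

// Assemble a 64-bit little-endian value byte by byte, independent of the
// host's byte order; bytes[0] is the least significant byte.
uint64_t decodeLittleEndian64(const unsigned char bytes[8]) {
  uint64_t value = 0;
  for (int i = 7; i >= 0; --i) {
    value = (value << 8) | bytes[i];
  }
  return value;
}
```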
@ -158,7 +151,7 @@ class PyTorchFileReader {
size_t next_offset = (cursor + kFieldAlignment) - (cursor % kFieldAlignment);
size_t pad_amount = next_offset - cursor;
cursor += pad_amount;
std::fseek(fp, cursor, SEEK_SET);
in.seekg(cursor);
}
// File format deserialization functions
@ -183,7 +176,7 @@ class PyTorchFileReader {
// Seek to location of file footer. We've already validated that the file
// length is a multiple of the alignment size
cursor = file_size - kFieldAlignment;
std::fseek(fp, cursor, SEEK_SET);
in.seekg(cursor);
auto tag = read64BitIntegerLittleEndian();
if (tag != RecordTags::FOOTER) {
throw std::runtime_error("File footer has wrong record type. Is this"
@ -197,13 +190,9 @@ class PyTorchFileReader {
}
};
class PyTorchFileWriter {
class PyTorchStreamWriter {
public:
PyTorchFileWriter(const std::string& filename) {
fp = std::fopen(filename.c_str(), "wb");
if (!fp) {
wrapPErrorAndThrow("Unable to open PyTorch file for writing!");
}
PyTorchStreamWriter(std::ostream& out_) : out(out_) {
writeFileHeader();
// In the case that we do not write any records into this file, the last
// record index written into the footer will point to the footer itself.
@ -224,15 +213,14 @@ class PyTorchFileWriter {
JIT_ASSERT(!finalized);
writeFileFooter();
finalized = true;
std::fclose(fp);
}
~PyTorchFileWriter() {
~PyTorchStreamWriter() {
if (!finalized) {
writeEndOfFile();
}
}
private:
FILE *fp;
std::ostream& out;
size_t cursor = 0;
bool finalized = false;
size_t last_record_idx = 0;
@ -240,17 +228,13 @@ class PyTorchFileWriter {
// Utility functions
void write64BitIntegerLittleEndian(const uint64_t value) {
// TODO endian swap on platforms that need it?
if (!std::fwrite(&value, 8u, 1u, fp)) {
wrapPErrorAndThrow("Unable to write to file!");
}
out.write(reinterpret_cast<const char *>(&value), 8);
cursor += 8u;
}
void writePad(const size_t num_bytes) {
static std::vector<char> pad_buffer(kFieldAlignment, kPadValue);
if (!std::fwrite(pad_buffer.data(), num_bytes, 1u, fp)) {
wrapPErrorAndThrow("Unable to write to file!");
}
out.write(pad_buffer.data(), num_bytes);
cursor += num_bytes;
}
@ -261,9 +245,7 @@ class PyTorchFileWriter {
}
void writeBuffer(const char* data, size_t size) {
if (!std::fwrite(data, size, 1u, fp)) {
wrapPErrorAndThrow("Unable to write to file!");
}
out.write(data, size);
cursor += size;
}
@ -281,5 +263,43 @@ class PyTorchFileWriter {
}
};
class PyTorchFileReader {
public:
PyTorchFileReader(const std::string& filename) :
in(filename, std::ios_base::binary),
stream_reader(in) {}
std::tuple<at::DataPtr, size_t> getLastRecord() {
return stream_reader.getLastRecord();
}
std::tuple<at::DataPtr, size_t> getRecordWithKey(uint64_t key) {
return stream_reader.getRecordWithKey(key);
}
private:
std::ifstream in;
PyTorchStreamReader stream_reader;
};
class PyTorchFileWriter {
public:
PyTorchFileWriter(const std::string& filename) :
out(filename, std::ios_base::binary),
stream_writer(out) {}
uint64_t writeRecord(const char* data, size_t size) {
return stream_writer.writeRecord(data, size);
}
void writeEndOfFile() {
stream_writer.writeEndOfFile();
out.close();
}
private:
std::ofstream out;
PyTorchStreamWriter stream_writer;
};
}} // namespace torch::jit
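
Reader and writer share the same alignment arithmetic for padding and seeking. A standalone sketch of that rule (the value 64 for `kFieldAlignment` is an assumption; it is not shown in this excerpt):

```cpp
#include <cassert>
#include <cstddef>

constexpr size_t kFieldAlignment = 64;  // assumed; defined elsewhere in the file

// Next field boundary, as computed by seekToNextAlignmentBoundary above.
size_t nextAlignmentBoundary(size_t cursor) {
  return (cursor + kFieldAlignment) - (cursor % kFieldAlignment);
}

int main() {
  assert(nextAlignmentBoundary(1) == 64);
  assert(nextAlignmentBoundary(63) == 64);
  assert(nextAlignmentBoundary(100) == 128);
  // Note: an already-aligned cursor still advances by a full block.
  assert(nextAlignmentBoundary(64) == 128);
  return 0;
}
```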

View File

@ -20,6 +20,8 @@ import copy
import numbers
import collections
import re
if sys.version_info[0] > 2:
import pathlib
def _parse_env(name, default, true_message, false_message):
@ -58,19 +60,27 @@ def scope(scope_name):
tracing_state.pop_scope()
def load(filename):
def load(f):
r"""
Load a ``ScriptModule`` previously saved with :func:`save <torch.jit.ScriptModule.save>`
Load a ``ScriptModule`` previously saved with :func:`save <torch.jit.save>`
.. DANGER::
All previously saved modules, no matter their device, are always loaded onto the CPU.
This is different from :func:`torch.load`'s semantics and may change in the future.
Arguments:
filename (string): the file to load
f: a file-like object (has to implement read, readline, tell, and seek),
or a string containing a file name
Returns:
A ``ScriptModule`` object.
Example:
>>> torch.jit.load('scriptmodule.pt')
# Load ScriptModule from io.BytesIO object
>>> with open('scriptmodule.pt', 'rb') as f:
...     buffer = io.BytesIO(f.read())
>>> torch.jit.load(buffer)
"""
m = ScriptModule()
@ -82,10 +92,48 @@ def load(filename):
curr = getattr(curr, name)
return curr
torch._C.import_ir_module(module_lookup, filename)
if isinstance(f, str) or \
(sys.version_info[0] == 2 and isinstance(f, unicode)) or \
(sys.version_info[0] == 3 and isinstance(f, pathlib.Path)):
torch._C.import_ir_module(module_lookup, f)
else:
torch._C.import_ir_module_from_buffer(module_lookup, f.read())
return m
def save(m, f):
"""
Saves a ScriptModule to a file or to a file-like object.
Args:
m: a ScriptModule to save
f: a file-like object (has to implement write and flush) or a string
containing a file name
.. warning::
If you are using Python 2, torch.save does NOT support StringIO.StringIO
as a valid file-like object. This is because the write method should return
the number of bytes written; StringIO.write() does not do this.
Please use something like io.BytesIO instead.
Example:
>>> m = torch.jit.ScriptModule()
>>> # Save to file
>>> torch.jit.save(m, 'scriptmodule.pt')
>>> # Save to io.BytesIO buffer
>>> buffer = io.BytesIO()
>>> torch.jit.save(m, buffer)
"""
if isinstance(f, str) or \
(sys.version_info[0] == 2 and isinstance(f, unicode)) or \
(sys.version_info[0] == 3 and isinstance(f, pathlib.Path)):
m.save(f)
else:
ret = m.save_to_buffer()
f.write(ret)
def get_trace_graph(f, args=(), kwargs=None):
"""
Trace a function or model, returning a tuple consisting of both the