Use streams in JIT serialization, allow JIT serialization to/from buffer (#11932)

Summary:
This PR replaces the use of `std::FILE` with `istream`/`ostream` for JIT serialization.
It uses this mechanism to add the possibility to serialize to/from binary buffers, in addition to files, both in `libtorch` and from Python.

`getExportImportCopy` in `test_jit.py` has been updated so that both file and buffer codepaths are exercised during tests.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/11932

Differential Revision: D10084303

Pulled By: apaszke

fbshipit-source-id: b850801b3932922fa1dbac6fdaed5063d58bc20d
This commit is contained in:
Luca Antiga 2018-09-28 07:41:26 -07:00 committed by Facebook Github Bot
parent d291cf7de6
commit 5be0baefa2
11 changed files with 201 additions and 77 deletions

View File

@ -240,14 +240,10 @@ class JitTestCase(TestCase):
imported = torch.jit.load(f.name) imported = torch.jit.load(f.name)
finally: finally:
os.unlink(f.name) os.unlink(f.name)
f = tempfile.NamedTemporaryFile(delete=False) buffer = io.BytesIO()
try: torch.jit.save(imported, buffer)
f.close() buffer.seek(0)
imported.save(f.name) return torch.jit.load(buffer)
imported = torch.jit.load(f.name)
finally:
os.unlink(f.name)
return imported
def assertGraphContains(self, graph, kind): def assertGraphContains(self, graph, kind):
self.assertTrue(any(n.kind() == kind for n in graph.nodes())) self.assertTrue(any(n.kind() == kind for n in graph.nodes()))

View File

@ -15,6 +15,7 @@
#include <vector> #include <vector>
#include <string> #include <string>
#include <sstream> #include <sstream>
#include <fstream>
namespace torch { namespace jit { namespace torch { namespace jit {
@ -425,7 +426,7 @@ void GraphEncoder::EncodeTensor(
class ModuleEncoder: public EncoderBase { class ModuleEncoder: public EncoderBase {
public: public:
ModuleEncoder(const script::Module &module, ModuleEncoder(const script::Module &module,
const std::string &filename); std::ostream& out);
private: private:
void EncodeModule(onnx::GraphProto *graph_proto, const script::Module &module); void EncodeModule(onnx::GraphProto *graph_proto, const script::Module &module);
@ -448,7 +449,7 @@ class ModuleEncoder: public EncoderBase {
virtual void EncodeTensor(onnx::TensorProto *tensor_proto, virtual void EncodeTensor(onnx::TensorProto *tensor_proto,
const at::Tensor &tensor, const at::Tensor &tensor,
const at::optional<std::string> external_ref) override; const at::optional<std::string> external_ref = {}) override;
virtual void EncodeIntermediateValueInfo(onnx::GraphProto *graph_proto, virtual void EncodeIntermediateValueInfo(onnx::GraphProto *graph_proto,
const Value* n) override; const Value* n) override;
@ -462,7 +463,7 @@ class ModuleEncoder: public EncoderBase {
const TypePtr& type, const TypePtr& type,
const std::string& name); const std::string& name);
PyTorchFileWriter file_writer_; PyTorchStreamWriter stream_writer_;
// Used to deduplicate tensor storages // Used to deduplicate tensor storages
std::unordered_map<const void*, uint64_t> storage_dedup_map_; std::unordered_map<const void*, uint64_t> storage_dedup_map_;
@ -475,9 +476,9 @@ class ModuleEncoder: public EncoderBase {
ModuleEncoder::ModuleEncoder( ModuleEncoder::ModuleEncoder(
const script::Module &module, const script::Module &module,
const std::string &filename) std::ostream& out)
: EncoderBase(onnx_torch::OperatorExportTypes::RAW, false), : EncoderBase(onnx_torch::OperatorExportTypes::RAW, false),
file_writer_(filename) { stream_writer_(out) {
model_proto_.set_doc_string("THIS PROTO IS NOT STANDARD ONNX"); model_proto_.set_doc_string("THIS PROTO IS NOT STANDARD ONNX");
EncodeModule(model_proto_.mutable_graph(), module); EncodeModule(model_proto_.mutable_graph(), module);
} }
@ -586,7 +587,7 @@ void ModuleEncoder::EncodeModule(
EncodeParameters(graph_proto, module, ""); EncodeParameters(graph_proto, module, "");
EncodeMethods(graph_proto, module, ""); EncodeMethods(graph_proto, module, "");
auto str = model_proto_.SerializeAsString(); auto str = model_proto_.SerializeAsString();
file_writer_.writeRecord(str.data(), str.size()); stream_writer_.writeRecord(str.data(), str.size());
} }
void ModuleEncoder::EncodeParameters( void ModuleEncoder::EncodeParameters(
@ -674,7 +675,7 @@ void ModuleEncoder::EncodeMethod(
void ModuleEncoder::EncodeTensor( void ModuleEncoder::EncodeTensor(
onnx::TensorProto *tensor_proto, onnx::TensorProto *tensor_proto,
const at::Tensor &tensor, const at::Tensor &tensor,
const at::optional<std::string> external_ref = {}) { const at::optional<std::string> external_ref) {
auto storage_ptr = tensor.storage().unsafeGetStorageImpl(); auto storage_ptr = tensor.storage().unsafeGetStorageImpl();
auto dedup_it = storage_dedup_map_.find(storage_ptr); auto dedup_it = storage_dedup_map_.find(storage_ptr);
if (dedup_it != storage_dedup_map_.end()) { if (dedup_it != storage_dedup_map_.end()) {
@ -693,7 +694,7 @@ void ModuleEncoder::EncodeTensor(
.cpu(); .cpu();
} }
auto record_number = file_writer_.writeRecord( auto record_number = stream_writer_.writeRecord(
static_cast<char*>(t.storage().data()), t.type().elementSizeInBytes() * t.storage().size()); static_cast<char*>(t.storage().data()), t.type().elementSizeInBytes() * t.storage().size());
tensor_proto->add_int64_data(record_number); tensor_proto->add_int64_data(record_number);
storage_dedup_map_[storage_ptr] = record_number; storage_dedup_map_[storage_ptr] = record_number;
@ -919,8 +920,14 @@ std::tuple<std::string, RawDataExportMap> ExportGraph(
graph_encoder.get_raw_data_export_map()); graph_encoder.get_raw_data_export_map());
} }
void ExportModule(const script::Module& module, std::ostream& out) {
ModuleEncoder(module, out);
}
void ExportModule(const script::Module& module, const std::string &filename) { void ExportModule(const script::Module& module, const std::string &filename) {
ModuleEncoder(module, filename); std::ofstream out(filename, std::ios_base::binary);
ExportModule(module, out);
} }
}} }}

View File

@ -4,6 +4,8 @@
#include "torch/csrc/jit/script/module.h" #include "torch/csrc/jit/script/module.h"
#include "torch/csrc/onnx/onnx.h" #include "torch/csrc/onnx/onnx.h"
#include <ostream>
namespace torch { namespace jit { namespace torch { namespace jit {
// This map is used to keep track of parameters that should be exported // This map is used to keep track of parameters that should be exported
@ -34,6 +36,10 @@ TORCH_API std::string PrettyPrintExportedGraph(
= ::torch::onnx::OperatorExportTypes::ONNX, = ::torch::onnx::OperatorExportTypes::ONNX,
bool google_printer = false); bool google_printer = false);
TORCH_API void ExportModule(
const script::Module& module,
std::ostream& out);
TORCH_API void ExportModule( TORCH_API void ExportModule(
const script::Module& module, const script::Module& module,
const std::string& filename); const std::string& filename);

View File

@ -11,6 +11,7 @@
#include <unordered_map> #include <unordered_map>
#include <vector> #include <vector>
#include <string> #include <string>
#include <fstream>
namespace torch { namespace jit { namespace torch { namespace jit {
@ -181,7 +182,7 @@ void DecoderBase::buildBlock(const onnx::GraphProto& graph_proto, Block* block,
class ModuleDecoder : DecoderBase { class ModuleDecoder : DecoderBase {
public: public:
ModuleDecoder(ModuleLookup module_lookup, ModuleDecoder(ModuleLookup module_lookup,
const std::string& filename); std::istream& in);
private: private:
virtual std::shared_ptr<Graph> buildGraph(const onnx::GraphProto& graph_proto) override; virtual std::shared_ptr<Graph> buildGraph(const onnx::GraphProto& graph_proto) override;
@ -205,7 +206,7 @@ class ModuleDecoder : DecoderBase {
ModuleLookup module_lookup, ModuleLookup module_lookup,
const std::string fullname); const std::string fullname);
PyTorchFileReader file_reader_; PyTorchStreamReader stream_reader_;
std::unordered_map<uint64_t, std::shared_ptr<at::Storage>> storage_map_; std::unordered_map<uint64_t, std::shared_ptr<at::Storage>> storage_map_;
std::unordered_map<std::string, const onnx::TypeProto*> value_type_map_; std::unordered_map<std::string, const onnx::TypeProto*> value_type_map_;
}; };
@ -319,7 +320,7 @@ at::Tensor ModuleDecoder::buildTensorCommon(
if (storage_it == storage_map_.end()) { if (storage_it == storage_map_.end()) {
at::DataPtr storage_ptr; at::DataPtr storage_ptr;
int64_t size; int64_t size;
std::tie(storage_ptr, size) = file_reader_.getRecordWithKey(record_number); std::tie(storage_ptr, size) = stream_reader_.getRecordWithKey(record_number);
auto storage = std::make_shared<at::Storage>( auto storage = std::make_shared<at::Storage>(
at::CPU(type).typeMeta(), at::CPU(type).typeMeta(),
std::move(storage_ptr), std::move(storage_ptr),
@ -353,10 +354,10 @@ std::pair<std::shared_ptr<script::Module>, std::string> ModuleDecoder::parseFull
ModuleDecoder::ModuleDecoder( ModuleDecoder::ModuleDecoder(
ModuleLookup module_lookup, ModuleLookup module_lookup,
const std::string &filename) : std::istream& in) :
file_reader_(filename) { stream_reader_(in) {
auto model_proto = onnx::ModelProto(); auto model_proto = onnx::ModelProto();
auto record = file_reader_.getLastRecord(); auto record = stream_reader_.getLastRecord();
model_proto.ParsePartialFromArray(std::get<0>(record).get(), std::get<1>(record)); model_proto.ParsePartialFromArray(std::get<0>(record).get(), std::get<1>(record));
auto graph_proto = model_proto.graph(); auto graph_proto = model_proto.graph();
@ -397,11 +398,19 @@ ModuleDecoder::ModuleDecoder(
void import_ir_module( void import_ir_module(
ModuleLookup module_lookup, ModuleLookup module_lookup,
const std::string& filename) { std::istream& in) {
ModuleDecoder(module_lookup, filename); ModuleDecoder(module_lookup, in);
} }
std::shared_ptr<script::Module> load(const std::string& filename) { void import_ir_module(
ModuleLookup module_lookup,
const std::string& filename) {
std::ifstream in(filename, std::ios_base::binary);
ModuleDecoder(module_lookup, in);
}
std::shared_ptr<script::Module> load(std::istream& in) {
auto module = std::make_shared<script::Module>(); auto module = std::make_shared<script::Module>();
auto module_lookup = [&](const std::vector<std::string>& qualified_name) { auto module_lookup = [&](const std::vector<std::string>& qualified_name) {
@ -414,7 +423,17 @@ std::shared_ptr<script::Module> load(const std::string& filename) {
} }
return curr; return curr;
}; };
ModuleDecoder(module_lookup, filename);
ModuleDecoder(module_lookup, in);
return module;
}
std::shared_ptr<script::Module> load(const std::string& filename) {
std::ifstream in(filename, std::ios_base::binary);
auto module = load(in);
return module; return module;
} }

View File

@ -3,6 +3,8 @@
#include "torch/csrc/jit/ir.h" #include "torch/csrc/jit/ir.h"
#include "torch/csrc/jit/script/module.h" #include "torch/csrc/jit/script/module.h"
#include <istream>
namespace torch { namespace torch {
namespace jit { namespace jit {
@ -13,11 +15,18 @@ TORCH_API void import_ir_module(
ModuleLookup module_lookup, ModuleLookup module_lookup,
const std::string& filename); const std::string& filename);
TORCH_API void import_ir_module(
ModuleLookup module_lookup,
std::istream& in);
/// Loads a serialized `script::Module` from the given `filename`. /// Loads a serialized `script::Module` from the given `filename`.
/// ///
/// The file stored at the location given in `filename` must contain a /// The file stored at the location given in `filename` must contain a
/// serialized `script::Module`, exported either via `ScriptModule.save()` in /// serialized `script::Module`, exported either via `ScriptModule.save()` in
/// Python or `torch::jit::ExportModule` in C++. /// Python or `torch::jit::ExportModule` in C++.
TORCH_API std::shared_ptr<script::Module> load(std::istream& in);
TORCH_API std::shared_ptr<script::Module> load(const std::string& filename); TORCH_API std::shared_ptr<script::Module> load(const std::string& filename);
} // namespace jit } // namespace jit

View File

@ -227,7 +227,6 @@ void initJITBindings(PyObject *module) {
return createPyObjectForStack(std::move(stack)); return createPyObjectForStack(std::move(stack));
}); });
py::class_<PyTorchFileWriter>(m, "PyTorchFileWriter") py::class_<PyTorchFileWriter>(m, "PyTorchFileWriter")
.def(py::init<std::string>()) .def(py::init<std::string>())
.def("write_record", &PyTorchFileWriter::writeRecord) .def("write_record", &PyTorchFileWriter::writeRecord)

View File

@ -371,7 +371,14 @@ void initJitScriptBindings(PyObject* module) {
// public. // public.
py::class_<Module, std::shared_ptr<Module>>(m, "ScriptModule") py::class_<Module, std::shared_ptr<Module>>(m, "ScriptModule")
.def(py::init<>()) .def(py::init<>())
.def("save", &Module::save) .def("save", [](std::shared_ptr<Module> m, const std::string& filename) {
m->save(filename);
})
.def("save_to_buffer", [](std::shared_ptr<Module> m) {
std::ostringstream buf;
m->save(buf);
return py::bytes(buf.str());
})
.def("_set_optimized", &Module::set_optimized) .def("_set_optimized", &Module::set_optimized)
.def( .def(
"_define", "_define",
@ -534,7 +541,13 @@ void initJitScriptBindings(PyObject* module) {
}); });
m.def("merge_type_from_type_comment", &mergeTypesFromTypeComment); m.def("merge_type_from_type_comment", &mergeTypesFromTypeComment);
m.def("import_ir_module", import_ir_module); m.def("import_ir_module", [](ModuleLookup module_lookup, const std::string& filename) {
import_ir_module(module_lookup, filename);
});
m.def("import_ir_module_from_buffer", [](ModuleLookup module_lookup, const std::string& buffer) {
std::istringstream in(buffer);
import_ir_module(module_lookup, in);
});
} }
} // namespace script } // namespace script

View File

@ -71,6 +71,10 @@ void Method::ensure_defined() {
} }
} }
void Module::save(std::ostream& out) {
ExportModule(*this, out);
}
void Module::save(const std::string& filename) { void Module::save(const std::string& filename) {
ExportModule(*this, filename); ExportModule(*this, filename);
} }

View File

@ -20,6 +20,7 @@
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include <vector> #include <vector>
#include <ostream>
// This file contains classes which assist in desugaring Python style // This file contains classes which assist in desugaring Python style
// modules and their methods into flattened graphs which don't have any // modules and their methods into flattened graphs which don't have any
@ -376,6 +377,8 @@ struct Module {
return get_method(method_name)({IValue(std::forward<Types>(args))...}); return get_method(method_name)({IValue(std::forward<Types>(args))...});
} }
void save(std::ostream& out);
void save(const std::string& filename); void save(const std::string& filename);
private: private:

View File

@ -3,6 +3,9 @@
#include <cstdio> #include <cstdio>
#include <cstring> #include <cstring>
#include <cerrno> #include <cerrno>
#include <istream>
#include <ostream>
#include <fstream>
namespace torch { namespace jit { namespace torch { namespace jit {
@ -75,25 +78,16 @@ namespace {
static constexpr uint64_t kFileFormatVersion = 0x1L; static constexpr uint64_t kFileFormatVersion = 0x1L;
static constexpr uint8_t kPadValue = 0xEF; static constexpr uint8_t kPadValue = 0xEF;
void wrapPErrorAndThrow(const std::string& msg) {
std::ostringstream oss;
oss << msg << " : " << strerror(errno);
throw std::runtime_error(oss.str());
}
} // namespace } // namespace
class PyTorchFileReader { class PyTorchStreamReader {
public: public:
PyTorchFileReader(std::string filename) { PyTorchStreamReader(std::istream& in_) : in(in_) {
fp = std::fopen(filename.c_str(), "rb");
if (!fp) {
wrapPErrorAndThrow("Couldn't open file for reading!");
}
// Store file size so we know when we're done reading because the f* APIs // Store file size so we know when we're done reading because the f* APIs
// don't do a good job of that // don't do a good job of that
std::fseek(fp, 0L, SEEK_END); in.seekg(0L, in.end);
file_size = std::ftell(fp); file_size = in.tellg();
std::fseek(fp, 0L, SEEK_SET); in.seekg(0L);
readAndValidateFileHeader(); readAndValidateFileHeader();
// Do this now since we're reasonably sure this is actually a PyT file from // Do this now since we're reasonably sure this is actually a PyT file from
// the header. // the header.
@ -115,7 +109,7 @@ class PyTorchFileReader {
} }
// Seek to the provided offset // Seek to the provided offset
cursor = key; cursor = key;
std::fseek(fp, cursor, SEEK_SET); in.seekg(cursor);
auto tag = read64BitIntegerLittleEndian(); auto tag = read64BitIntegerLittleEndian();
if (tag != RecordTags::STORAGE) { if (tag != RecordTags::STORAGE) {
throw std::runtime_error("Attempted to read a record of non-storage type"); throw std::runtime_error("Attempted to read a record of non-storage type");
@ -124,18 +118,16 @@ class PyTorchFileReader {
seekToNextAlignmentBoundary(); seekToNextAlignmentBoundary();
auto ptr = malloc(size); auto ptr = malloc(size);
at::DataPtr retval(ptr, ptr, free, at::kCPU); at::DataPtr retval(ptr, ptr, free, at::kCPU);
if (!std::fread(ptr, size, 1, fp)) {
wrapPErrorAndThrow("Failed to read data from record"); in.read((char*)ptr, size);
}
cursor += size; cursor += size;
seekToNextAlignmentBoundary(); seekToNextAlignmentBoundary();
return std::tuple<at::DataPtr, size_t>(std::move(retval), size); return std::tuple<at::DataPtr, size_t>(std::move(retval), size);
} }
~PyTorchFileReader() { ~PyTorchStreamReader() {
std::fclose(fp);
} }
private: private:
FILE *fp; std::istream& in;
size_t cursor = 0; size_t cursor = 0;
size_t file_size; size_t file_size;
size_t last_record_offset; size_t last_record_offset;
@ -144,8 +136,9 @@ class PyTorchFileReader {
uint64_t read64BitIntegerLittleEndian() { uint64_t read64BitIntegerLittleEndian() {
uint64_t retval; uint64_t retval;
// TODO endian swap on platforms that need it? // TODO endian swap on platforms that need it?
size_t read_bytes = std::fread(&retval, 1u, 8u, fp); in.read(reinterpret_cast<char *>(&retval), 8);
if (read_bytes != 8u) { std::streamsize read_bytes = in.gcount();
if (read_bytes != 8) {
std::ostringstream errmsg; std::ostringstream errmsg;
errmsg << "Expected to read 8 bytes but got " << read_bytes; errmsg << "Expected to read 8 bytes but got " << read_bytes;
throw std::runtime_error(errmsg.str()); throw std::runtime_error(errmsg.str());
@ -158,7 +151,7 @@ class PyTorchFileReader {
size_t next_offset = (cursor + kFieldAlignment) - (cursor % kFieldAlignment); size_t next_offset = (cursor + kFieldAlignment) - (cursor % kFieldAlignment);
size_t pad_amount = next_offset - cursor; size_t pad_amount = next_offset - cursor;
cursor += pad_amount; cursor += pad_amount;
std::fseek(fp, cursor, SEEK_SET); in.seekg(cursor);
} }
// File format deserialization functions // File format deserialization functions
@ -183,7 +176,7 @@ class PyTorchFileReader {
// Seek to location of file footer. We've already validated that the file // Seek to location of file footer. We've already validated that the file
// length is a multiple of the alignment size // length is a multiple of the alignment size
cursor = file_size - kFieldAlignment; cursor = file_size - kFieldAlignment;
std::fseek(fp, cursor, SEEK_SET); in.seekg(cursor);
auto tag = read64BitIntegerLittleEndian(); auto tag = read64BitIntegerLittleEndian();
if (tag != RecordTags::FOOTER) { if (tag != RecordTags::FOOTER) {
throw std::runtime_error("File footer has wrong record type. Is this" throw std::runtime_error("File footer has wrong record type. Is this"
@ -197,13 +190,9 @@ class PyTorchFileReader {
} }
}; };
class PyTorchFileWriter { class PyTorchStreamWriter {
public: public:
PyTorchFileWriter(const std::string& filename) { PyTorchStreamWriter(std::ostream& out_) : out(out_) {
fp = std::fopen(filename.c_str(), "wb");
if (!fp) {
wrapPErrorAndThrow("Unable to open PyTorch file for writing!");
}
writeFileHeader(); writeFileHeader();
// In the case that we do not write any records into this file, the last // In the case that we do not write any records into this file, the last
// record index written into the footer will point to the footer itself. // record index written into the footer will point to the footer itself.
@ -224,15 +213,14 @@ class PyTorchFileWriter {
JIT_ASSERT(!finalized); JIT_ASSERT(!finalized);
writeFileFooter(); writeFileFooter();
finalized = true; finalized = true;
std::fclose(fp);
} }
~PyTorchFileWriter() { ~PyTorchStreamWriter() {
if (!finalized) { if (!finalized) {
writeEndOfFile(); writeEndOfFile();
} }
} }
private: private:
FILE *fp; std::ostream& out;
size_t cursor = 0; size_t cursor = 0;
bool finalized = false; bool finalized = false;
size_t last_record_idx = 0; size_t last_record_idx = 0;
@ -240,17 +228,13 @@ class PyTorchFileWriter {
// Utility functions // Utility functions
void write64BitIntegerLittleEndian(const uint64_t value) { void write64BitIntegerLittleEndian(const uint64_t value) {
// TODO endian swap on platforms that need it? // TODO endian swap on platforms that need it?
if (!std::fwrite(&value, 8u, 1u, fp)) { out.write(reinterpret_cast<const char *>(&value), 8);
wrapPErrorAndThrow("Unable to write to file!");
}
cursor += 8u; cursor += 8u;
} }
void writePad(const size_t num_bytes) { void writePad(const size_t num_bytes) {
static std::vector<char> pad_buffer(kPadValue, kFieldAlignment); static std::vector<char> pad_buffer(kPadValue, kFieldAlignment);
if (!std::fwrite(pad_buffer.data(), num_bytes, 1u, fp)) { out.write(pad_buffer.data(), num_bytes);
wrapPErrorAndThrow("Unable to write to file!");
}
cursor += num_bytes; cursor += num_bytes;
} }
@ -261,9 +245,7 @@ class PyTorchFileWriter {
} }
void writeBuffer(const char* data, size_t size) { void writeBuffer(const char* data, size_t size) {
if (!std::fwrite(data, size, 1u, fp)) { out.write(data, size);
wrapPErrorAndThrow("Unable to write to file!");
}
cursor += size; cursor += size;
} }
@ -281,5 +263,43 @@ class PyTorchFileWriter {
} }
}; };
class PyTorchFileReader {
public:
PyTorchFileReader(const std::string& filename) :
in(filename, std::ios_base::binary),
stream_reader(in) {}
std::tuple<at::DataPtr, size_t> getLastRecord() {
return stream_reader.getLastRecord();
}
std::tuple<at::DataPtr, size_t> getRecordWithKey(uint64_t key) {
return stream_reader.getRecordWithKey(key);
}
private:
std::ifstream in;
PyTorchStreamReader stream_reader;
};
class PyTorchFileWriter {
public:
PyTorchFileWriter(const std::string& filename) :
out(filename, std::ios_base::binary),
stream_writer(out) {}
uint64_t writeRecord(const char* data, size_t size) {
return stream_writer.writeRecord(data, size);
}
void writeEndOfFile() {
stream_writer.writeEndOfFile();
out.close();
}
private:
std::ofstream out;
PyTorchStreamWriter stream_writer;
};
}} // namespace torch::jit }} // namespace torch::jit

View File

@ -20,6 +20,8 @@ import copy
import numbers import numbers
import collections import collections
import re import re
if sys.version_info[0] > 2:
import pathlib
def _parse_env(name, default, true_message, false_message): def _parse_env(name, default, true_message, false_message):
@ -58,19 +60,27 @@ def scope(scope_name):
tracing_state.pop_scope() tracing_state.pop_scope()
def load(filename): def load(f):
r""" r"""
Load a ``ScriptModule`` previously saved with :func:`save <torch.jit.ScriptModule.save>` Load a ``ScriptModule`` previously saved with :func:`save <torch.jit.save>`
.. DANGER:: .. DANGER::
All previously saved modules, no matter their device, are always loaded onto the CPU. All previously saved modules, no matter their device, are always loaded onto the CPU.
This is different from :func:`torch.load`'s semantics and may change in the future. This is different from :func:`torch.load`'s semantics and may change in the future.
Arguments: Arguments:
filename (string): the file to load f: a file-like object (has to implement read, readline, tell, and seek),
or a string containing a file name
Returns: Returns:
A ``ScriptModule`` object. A ``ScriptModule`` object.
Example:
>>> torch.jit.load('scriptmodule.pt')
# Load ScriptModule from io.BytesIO object
>>> with open('scriptmodule.pt', 'rb') as f:
buffer = io.BytesIO(f.read())
>>> torch.jit.load(buffer)
""" """
m = ScriptModule() m = ScriptModule()
@ -82,10 +92,48 @@ def load(filename):
curr = getattr(curr, name) curr = getattr(curr, name)
return curr return curr
torch._C.import_ir_module(module_lookup, filename) if isinstance(f, str) or \
(sys.version_info[0] == 2 and isinstance(f, unicode)) or \
(sys.version_info[0] == 3 and isinstance(f, pathlib.Path)):
torch._C.import_ir_module(module_lookup, f)
else:
torch._C.import_ir_module_from_buffer(module_lookup, f.read())
return m return m
def save(m, f):
"""
Saves a ScriptModule to a file.
Args:
m: a ScriptModule to save
f: a file-like object (has to implement write and flush) or a string
containing a file name
.. warning::
If you are using Python 2, torch.save does NOT support StringIO.StringIO
as a valid file-like object. This is because the write method should return
the number of bytes written; StringIO.write() does not do this.
Please use something like io.BytesIO instead.
Example:
>>> m = torch.jit.ScriptModule()
>>> # Save to file
>>> torch.jit.save(m, 'scriptmodule.pt')
>>> # Save to io.BytesIO buffer
>>> buffer = io.BytesIO()
>>> torch.jit.save(m, buffer)
"""
if isinstance(f, str) or \
(sys.version_info[0] == 2 and isinstance(f, unicode)) or \
(sys.version_info[0] == 3 and isinstance(f, pathlib.Path)):
m.save(f)
else:
ret = m.save_to_buffer()
f.write(ret)
def get_trace_graph(f, args=(), kwargs=None): def get_trace_graph(f, args=(), kwargs=None):
""" """
Trace a function or model, returning a tuple consisting of the both the Trace a function or model, returning a tuple consisting of the both the