mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
Use streams in JIT serialization, allow JIT serialization to/from buffer (#11932)
Summary: This PR replaces the use of `std::FILE` with `istream`/`ostream` for JIT serialization. It uses this mechanism to add the possibility to serialize to/from binary buffers, in addition to files, both in `libtorch` and from Python. `getExportImportCopy` in `test_jit.py` has been updated so that both file and buffer codepaths are exercised during tests. Pull Request resolved: https://github.com/pytorch/pytorch/pull/11932 Differential Revision: D10084303 Pulled By: apaszke fbshipit-source-id: b850801b3932922fa1dbac6fdaed5063d58bc20d
This commit is contained in:
parent
d291cf7de6
commit
5be0baefa2
|
|
@ -240,14 +240,10 @@ class JitTestCase(TestCase):
|
|||
imported = torch.jit.load(f.name)
|
||||
finally:
|
||||
os.unlink(f.name)
|
||||
f = tempfile.NamedTemporaryFile(delete=False)
|
||||
try:
|
||||
f.close()
|
||||
imported.save(f.name)
|
||||
imported = torch.jit.load(f.name)
|
||||
finally:
|
||||
os.unlink(f.name)
|
||||
return imported
|
||||
buffer = io.BytesIO()
|
||||
torch.jit.save(imported, buffer)
|
||||
buffer.seek(0)
|
||||
return torch.jit.load(buffer)
|
||||
|
||||
def assertGraphContains(self, graph, kind):
|
||||
self.assertTrue(any(n.kind() == kind for n in graph.nodes()))
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@
|
|||
#include <vector>
|
||||
#include <string>
|
||||
#include <sstream>
|
||||
#include <fstream>
|
||||
|
||||
namespace torch { namespace jit {
|
||||
|
||||
|
|
@ -425,7 +426,7 @@ void GraphEncoder::EncodeTensor(
|
|||
class ModuleEncoder: public EncoderBase {
|
||||
public:
|
||||
ModuleEncoder(const script::Module &module,
|
||||
const std::string &filename);
|
||||
std::ostream& out);
|
||||
|
||||
private:
|
||||
void EncodeModule(onnx::GraphProto *graph_proto, const script::Module &module);
|
||||
|
|
@ -448,7 +449,7 @@ class ModuleEncoder: public EncoderBase {
|
|||
|
||||
virtual void EncodeTensor(onnx::TensorProto *tensor_proto,
|
||||
const at::Tensor &tensor,
|
||||
const at::optional<std::string> external_ref) override;
|
||||
const at::optional<std::string> external_ref = {}) override;
|
||||
|
||||
virtual void EncodeIntermediateValueInfo(onnx::GraphProto *graph_proto,
|
||||
const Value* n) override;
|
||||
|
|
@ -462,7 +463,7 @@ class ModuleEncoder: public EncoderBase {
|
|||
const TypePtr& type,
|
||||
const std::string& name);
|
||||
|
||||
PyTorchFileWriter file_writer_;
|
||||
PyTorchStreamWriter stream_writer_;
|
||||
// Used to deduplicate tensor storages
|
||||
std::unordered_map<const void*, uint64_t> storage_dedup_map_;
|
||||
|
||||
|
|
@ -475,9 +476,9 @@ class ModuleEncoder: public EncoderBase {
|
|||
|
||||
ModuleEncoder::ModuleEncoder(
|
||||
const script::Module &module,
|
||||
const std::string &filename)
|
||||
std::ostream& out)
|
||||
: EncoderBase(onnx_torch::OperatorExportTypes::RAW, false),
|
||||
file_writer_(filename) {
|
||||
stream_writer_(out) {
|
||||
model_proto_.set_doc_string("THIS PROTO IS NOT STANDARD ONNX");
|
||||
EncodeModule(model_proto_.mutable_graph(), module);
|
||||
}
|
||||
|
|
@ -586,7 +587,7 @@ void ModuleEncoder::EncodeModule(
|
|||
EncodeParameters(graph_proto, module, "");
|
||||
EncodeMethods(graph_proto, module, "");
|
||||
auto str = model_proto_.SerializeAsString();
|
||||
file_writer_.writeRecord(str.data(), str.size());
|
||||
stream_writer_.writeRecord(str.data(), str.size());
|
||||
}
|
||||
|
||||
void ModuleEncoder::EncodeParameters(
|
||||
|
|
@ -674,7 +675,7 @@ void ModuleEncoder::EncodeMethod(
|
|||
void ModuleEncoder::EncodeTensor(
|
||||
onnx::TensorProto *tensor_proto,
|
||||
const at::Tensor &tensor,
|
||||
const at::optional<std::string> external_ref = {}) {
|
||||
const at::optional<std::string> external_ref) {
|
||||
auto storage_ptr = tensor.storage().unsafeGetStorageImpl();
|
||||
auto dedup_it = storage_dedup_map_.find(storage_ptr);
|
||||
if (dedup_it != storage_dedup_map_.end()) {
|
||||
|
|
@ -693,7 +694,7 @@ void ModuleEncoder::EncodeTensor(
|
|||
.cpu();
|
||||
}
|
||||
|
||||
auto record_number = file_writer_.writeRecord(
|
||||
auto record_number = stream_writer_.writeRecord(
|
||||
static_cast<char*>(t.storage().data()), t.type().elementSizeInBytes() * t.storage().size());
|
||||
tensor_proto->add_int64_data(record_number);
|
||||
storage_dedup_map_[storage_ptr] = record_number;
|
||||
|
|
@ -919,8 +920,14 @@ std::tuple<std::string, RawDataExportMap> ExportGraph(
|
|||
graph_encoder.get_raw_data_export_map());
|
||||
}
|
||||
|
||||
void ExportModule(const script::Module& module, std::ostream& out) {
|
||||
ModuleEncoder(module, out);
|
||||
}
|
||||
|
||||
void ExportModule(const script::Module& module, const std::string &filename) {
|
||||
ModuleEncoder(module, filename);
|
||||
std::ofstream out(filename, std::ios_base::binary);
|
||||
|
||||
ExportModule(module, out);
|
||||
}
|
||||
|
||||
}}
|
||||
|
|
|
|||
|
|
@ -4,6 +4,8 @@
|
|||
#include "torch/csrc/jit/script/module.h"
|
||||
#include "torch/csrc/onnx/onnx.h"
|
||||
|
||||
#include <ostream>
|
||||
|
||||
namespace torch { namespace jit {
|
||||
|
||||
// This map is used to keep track of parameters that should be exported
|
||||
|
|
@ -34,6 +36,10 @@ TORCH_API std::string PrettyPrintExportedGraph(
|
|||
= ::torch::onnx::OperatorExportTypes::ONNX,
|
||||
bool google_printer = false);
|
||||
|
||||
TORCH_API void ExportModule(
|
||||
const script::Module& module,
|
||||
std::ostream& out);
|
||||
|
||||
TORCH_API void ExportModule(
|
||||
const script::Module& module,
|
||||
const std::string& filename);
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@
|
|||
#include <unordered_map>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <fstream>
|
||||
|
||||
namespace torch { namespace jit {
|
||||
|
||||
|
|
@ -181,7 +182,7 @@ void DecoderBase::buildBlock(const onnx::GraphProto& graph_proto, Block* block,
|
|||
class ModuleDecoder : DecoderBase {
|
||||
public:
|
||||
ModuleDecoder(ModuleLookup module_lookup,
|
||||
const std::string& filename);
|
||||
std::istream& in);
|
||||
|
||||
private:
|
||||
virtual std::shared_ptr<Graph> buildGraph(const onnx::GraphProto& graph_proto) override;
|
||||
|
|
@ -205,7 +206,7 @@ class ModuleDecoder : DecoderBase {
|
|||
ModuleLookup module_lookup,
|
||||
const std::string fullname);
|
||||
|
||||
PyTorchFileReader file_reader_;
|
||||
PyTorchStreamReader stream_reader_;
|
||||
std::unordered_map<uint64_t, std::shared_ptr<at::Storage>> storage_map_;
|
||||
std::unordered_map<std::string, const onnx::TypeProto*> value_type_map_;
|
||||
};
|
||||
|
|
@ -319,7 +320,7 @@ at::Tensor ModuleDecoder::buildTensorCommon(
|
|||
if (storage_it == storage_map_.end()) {
|
||||
at::DataPtr storage_ptr;
|
||||
int64_t size;
|
||||
std::tie(storage_ptr, size) = file_reader_.getRecordWithKey(record_number);
|
||||
std::tie(storage_ptr, size) = stream_reader_.getRecordWithKey(record_number);
|
||||
auto storage = std::make_shared<at::Storage>(
|
||||
at::CPU(type).typeMeta(),
|
||||
std::move(storage_ptr),
|
||||
|
|
@ -353,10 +354,10 @@ std::pair<std::shared_ptr<script::Module>, std::string> ModuleDecoder::parseFull
|
|||
|
||||
ModuleDecoder::ModuleDecoder(
|
||||
ModuleLookup module_lookup,
|
||||
const std::string &filename) :
|
||||
file_reader_(filename) {
|
||||
std::istream& in) :
|
||||
stream_reader_(in) {
|
||||
auto model_proto = onnx::ModelProto();
|
||||
auto record = file_reader_.getLastRecord();
|
||||
auto record = stream_reader_.getLastRecord();
|
||||
model_proto.ParsePartialFromArray(std::get<0>(record).get(), std::get<1>(record));
|
||||
auto graph_proto = model_proto.graph();
|
||||
|
||||
|
|
@ -397,11 +398,19 @@ ModuleDecoder::ModuleDecoder(
|
|||
|
||||
void import_ir_module(
|
||||
ModuleLookup module_lookup,
|
||||
const std::string& filename) {
|
||||
ModuleDecoder(module_lookup, filename);
|
||||
std::istream& in) {
|
||||
ModuleDecoder(module_lookup, in);
|
||||
}
|
||||
|
||||
std::shared_ptr<script::Module> load(const std::string& filename) {
|
||||
void import_ir_module(
|
||||
ModuleLookup module_lookup,
|
||||
const std::string& filename) {
|
||||
std::ifstream in(filename, std::ios_base::binary);
|
||||
|
||||
ModuleDecoder(module_lookup, in);
|
||||
}
|
||||
|
||||
std::shared_ptr<script::Module> load(std::istream& in) {
|
||||
auto module = std::make_shared<script::Module>();
|
||||
|
||||
auto module_lookup = [&](const std::vector<std::string>& qualified_name) {
|
||||
|
|
@ -414,7 +423,17 @@ std::shared_ptr<script::Module> load(const std::string& filename) {
|
|||
}
|
||||
return curr;
|
||||
};
|
||||
ModuleDecoder(module_lookup, filename);
|
||||
|
||||
ModuleDecoder(module_lookup, in);
|
||||
|
||||
return module;
|
||||
}
|
||||
|
||||
std::shared_ptr<script::Module> load(const std::string& filename) {
|
||||
std::ifstream in(filename, std::ios_base::binary);
|
||||
|
||||
auto module = load(in);
|
||||
|
||||
return module;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -3,6 +3,8 @@
|
|||
#include "torch/csrc/jit/ir.h"
|
||||
#include "torch/csrc/jit/script/module.h"
|
||||
|
||||
#include <istream>
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
|
||||
|
|
@ -13,11 +15,18 @@ TORCH_API void import_ir_module(
|
|||
ModuleLookup module_lookup,
|
||||
const std::string& filename);
|
||||
|
||||
TORCH_API void import_ir_module(
|
||||
ModuleLookup module_lookup,
|
||||
std::istream& in);
|
||||
|
||||
/// Loads a serialized `script::Module` from the given `filename`.
|
||||
///
|
||||
/// The file stored at the location given in `filename` must contain a
|
||||
/// serialized `script::Module`, exported either via `ScriptModule.save()` in
|
||||
/// Python or `torch::jit::ExportModule` in C++.
|
||||
|
||||
TORCH_API std::shared_ptr<script::Module> load(std::istream& in);
|
||||
|
||||
TORCH_API std::shared_ptr<script::Module> load(const std::string& filename);
|
||||
|
||||
} // namespace jit
|
||||
|
|
|
|||
|
|
@ -227,7 +227,6 @@ void initJITBindings(PyObject *module) {
|
|||
return createPyObjectForStack(std::move(stack));
|
||||
});
|
||||
|
||||
|
||||
py::class_<PyTorchFileWriter>(m, "PyTorchFileWriter")
|
||||
.def(py::init<std::string>())
|
||||
.def("write_record", &PyTorchFileWriter::writeRecord)
|
||||
|
|
|
|||
|
|
@ -371,7 +371,14 @@ void initJitScriptBindings(PyObject* module) {
|
|||
// public.
|
||||
py::class_<Module, std::shared_ptr<Module>>(m, "ScriptModule")
|
||||
.def(py::init<>())
|
||||
.def("save", &Module::save)
|
||||
.def("save", [](std::shared_ptr<Module> m, const std::string& filename) {
|
||||
m->save(filename);
|
||||
})
|
||||
.def("save_to_buffer", [](std::shared_ptr<Module> m) {
|
||||
std::ostringstream buf;
|
||||
m->save(buf);
|
||||
return py::bytes(buf.str());
|
||||
})
|
||||
.def("_set_optimized", &Module::set_optimized)
|
||||
.def(
|
||||
"_define",
|
||||
|
|
@ -534,7 +541,13 @@ void initJitScriptBindings(PyObject* module) {
|
|||
});
|
||||
|
||||
m.def("merge_type_from_type_comment", &mergeTypesFromTypeComment);
|
||||
m.def("import_ir_module", import_ir_module);
|
||||
m.def("import_ir_module", [](ModuleLookup module_lookup, const std::string& filename) {
|
||||
import_ir_module(module_lookup, filename);
|
||||
});
|
||||
m.def("import_ir_module_from_buffer", [](ModuleLookup module_lookup, const std::string& buffer) {
|
||||
std::istringstream in(buffer);
|
||||
import_ir_module(module_lookup, in);
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace script
|
||||
|
|
|
|||
|
|
@ -71,6 +71,10 @@ void Method::ensure_defined() {
|
|||
}
|
||||
}
|
||||
|
||||
void Module::save(std::ostream& out) {
|
||||
ExportModule(*this, out);
|
||||
}
|
||||
|
||||
void Module::save(const std::string& filename) {
|
||||
ExportModule(*this, filename);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -20,6 +20,7 @@
|
|||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
#include <ostream>
|
||||
|
||||
// This file contains classes which assist in desugaring Python style
|
||||
// modules and their methods into flattened graphs which don't have any
|
||||
|
|
@ -376,6 +377,8 @@ struct Module {
|
|||
return get_method(method_name)({IValue(std::forward<Types>(args))...});
|
||||
}
|
||||
|
||||
void save(std::ostream& out);
|
||||
|
||||
void save(const std::string& filename);
|
||||
|
||||
private:
|
||||
|
|
|
|||
|
|
@ -3,6 +3,9 @@
|
|||
#include <cstdio>
|
||||
#include <cstring>
|
||||
#include <cerrno>
|
||||
#include <istream>
|
||||
#include <ostream>
|
||||
#include <fstream>
|
||||
|
||||
namespace torch { namespace jit {
|
||||
|
||||
|
|
@ -75,25 +78,16 @@ namespace {
|
|||
static constexpr uint64_t kFileFormatVersion = 0x1L;
|
||||
static constexpr uint8_t kPadValue = 0xEF;
|
||||
|
||||
void wrapPErrorAndThrow(const std::string& msg) {
|
||||
std::ostringstream oss;
|
||||
oss << msg << " : " << strerror(errno);
|
||||
throw std::runtime_error(oss.str());
|
||||
}
|
||||
} // namespace
|
||||
|
||||
class PyTorchFileReader {
|
||||
class PyTorchStreamReader {
|
||||
public:
|
||||
PyTorchFileReader(std::string filename) {
|
||||
fp = std::fopen(filename.c_str(), "rb");
|
||||
if (!fp) {
|
||||
wrapPErrorAndThrow("Couldn't open file for reading!");
|
||||
}
|
||||
PyTorchStreamReader(std::istream& in_) : in(in_) {
|
||||
// Store file size so we know when we're done reading because the f* APIs
|
||||
// don't do a good job of that
|
||||
std::fseek(fp, 0L, SEEK_END);
|
||||
file_size = std::ftell(fp);
|
||||
std::fseek(fp, 0L, SEEK_SET);
|
||||
in.seekg(0L, in.end);
|
||||
file_size = in.tellg();
|
||||
in.seekg(0L);
|
||||
readAndValidateFileHeader();
|
||||
// Do this now since we're reasonably sure this is actually a PyT file from
|
||||
// the header.
|
||||
|
|
@ -115,7 +109,7 @@ class PyTorchFileReader {
|
|||
}
|
||||
// Seek to the provided offset
|
||||
cursor = key;
|
||||
std::fseek(fp, cursor, SEEK_SET);
|
||||
in.seekg(cursor);
|
||||
auto tag = read64BitIntegerLittleEndian();
|
||||
if (tag != RecordTags::STORAGE) {
|
||||
throw std::runtime_error("Attempted to read a record of non-storage type");
|
||||
|
|
@ -124,18 +118,16 @@ class PyTorchFileReader {
|
|||
seekToNextAlignmentBoundary();
|
||||
auto ptr = malloc(size);
|
||||
at::DataPtr retval(ptr, ptr, free, at::kCPU);
|
||||
if (!std::fread(ptr, size, 1, fp)) {
|
||||
wrapPErrorAndThrow("Failed to read data from record");
|
||||
}
|
||||
|
||||
in.read((char*)ptr, size);
|
||||
cursor += size;
|
||||
seekToNextAlignmentBoundary();
|
||||
return std::tuple<at::DataPtr, size_t>(std::move(retval), size);
|
||||
}
|
||||
~PyTorchFileReader() {
|
||||
std::fclose(fp);
|
||||
~PyTorchStreamReader() {
|
||||
}
|
||||
private:
|
||||
FILE *fp;
|
||||
std::istream& in;
|
||||
size_t cursor = 0;
|
||||
size_t file_size;
|
||||
size_t last_record_offset;
|
||||
|
|
@ -144,8 +136,9 @@ class PyTorchFileReader {
|
|||
uint64_t read64BitIntegerLittleEndian() {
|
||||
uint64_t retval;
|
||||
// TODO endian swap on platforms that need it?
|
||||
size_t read_bytes = std::fread(&retval, 1u, 8u, fp);
|
||||
if (read_bytes != 8u) {
|
||||
in.read(reinterpret_cast<char *>(&retval), 8);
|
||||
std::streamsize read_bytes = in.gcount();
|
||||
if (read_bytes != 8) {
|
||||
std::ostringstream errmsg;
|
||||
errmsg << "Expected to read 8 bytes but got " << read_bytes;
|
||||
throw std::runtime_error(errmsg.str());
|
||||
|
|
@ -158,7 +151,7 @@ class PyTorchFileReader {
|
|||
size_t next_offset = (cursor + kFieldAlignment) - (cursor % kFieldAlignment);
|
||||
size_t pad_amount = next_offset - cursor;
|
||||
cursor += pad_amount;
|
||||
std::fseek(fp, cursor, SEEK_SET);
|
||||
in.seekg(cursor);
|
||||
}
|
||||
|
||||
// File format deserialization functions
|
||||
|
|
@ -183,7 +176,7 @@ class PyTorchFileReader {
|
|||
// Seek to location of file footer. We've already validated that the file
|
||||
// length is a multiple of the alignment size
|
||||
cursor = file_size - kFieldAlignment;
|
||||
std::fseek(fp, cursor, SEEK_SET);
|
||||
in.seekg(cursor);
|
||||
auto tag = read64BitIntegerLittleEndian();
|
||||
if (tag != RecordTags::FOOTER) {
|
||||
throw std::runtime_error("File footer has wrong record type. Is this"
|
||||
|
|
@ -197,13 +190,9 @@ class PyTorchFileReader {
|
|||
}
|
||||
};
|
||||
|
||||
class PyTorchFileWriter {
|
||||
class PyTorchStreamWriter {
|
||||
public:
|
||||
PyTorchFileWriter(const std::string& filename) {
|
||||
fp = std::fopen(filename.c_str(), "wb");
|
||||
if (!fp) {
|
||||
wrapPErrorAndThrow("Unable to open PyTorch file for writing!");
|
||||
}
|
||||
PyTorchStreamWriter(std::ostream& out_) : out(out_) {
|
||||
writeFileHeader();
|
||||
// In the case that we do not write any records into this file, the last
|
||||
// record index written into the footer will point to the footer itself.
|
||||
|
|
@ -224,15 +213,14 @@ class PyTorchFileWriter {
|
|||
JIT_ASSERT(!finalized);
|
||||
writeFileFooter();
|
||||
finalized = true;
|
||||
std::fclose(fp);
|
||||
}
|
||||
~PyTorchFileWriter() {
|
||||
~PyTorchStreamWriter() {
|
||||
if (!finalized) {
|
||||
writeEndOfFile();
|
||||
}
|
||||
}
|
||||
private:
|
||||
FILE *fp;
|
||||
std::ostream& out;
|
||||
size_t cursor = 0;
|
||||
bool finalized = false;
|
||||
size_t last_record_idx = 0;
|
||||
|
|
@ -240,17 +228,13 @@ class PyTorchFileWriter {
|
|||
// Utility functions
|
||||
void write64BitIntegerLittleEndian(const uint64_t value) {
|
||||
// TODO endian swap on platforms that need it?
|
||||
if (!std::fwrite(&value, 8u, 1u, fp)) {
|
||||
wrapPErrorAndThrow("Unable to write to file!");
|
||||
}
|
||||
out.write(reinterpret_cast<const char *>(&value), 8);
|
||||
cursor += 8u;
|
||||
}
|
||||
|
||||
void writePad(const size_t num_bytes) {
|
||||
static std::vector<char> pad_buffer(kPadValue, kFieldAlignment);
|
||||
if (!std::fwrite(pad_buffer.data(), num_bytes, 1u, fp)) {
|
||||
wrapPErrorAndThrow("Unable to write to file!");
|
||||
}
|
||||
out.write(pad_buffer.data(), num_bytes);
|
||||
cursor += num_bytes;
|
||||
}
|
||||
|
||||
|
|
@ -261,9 +245,7 @@ class PyTorchFileWriter {
|
|||
}
|
||||
|
||||
void writeBuffer(const char* data, size_t size) {
|
||||
if (!std::fwrite(data, size, 1u, fp)) {
|
||||
wrapPErrorAndThrow("Unable to write to file!");
|
||||
}
|
||||
out.write(data, size);
|
||||
cursor += size;
|
||||
}
|
||||
|
||||
|
|
@ -281,5 +263,43 @@ class PyTorchFileWriter {
|
|||
}
|
||||
};
|
||||
|
||||
class PyTorchFileReader {
|
||||
public:
|
||||
PyTorchFileReader(const std::string& filename) :
|
||||
in(filename, std::ios_base::binary),
|
||||
stream_reader(in) {}
|
||||
|
||||
std::tuple<at::DataPtr, size_t> getLastRecord() {
|
||||
return stream_reader.getLastRecord();
|
||||
}
|
||||
|
||||
std::tuple<at::DataPtr, size_t> getRecordWithKey(uint64_t key) {
|
||||
return stream_reader.getRecordWithKey(key);
|
||||
}
|
||||
|
||||
private:
|
||||
std::ifstream in;
|
||||
PyTorchStreamReader stream_reader;
|
||||
};
|
||||
|
||||
class PyTorchFileWriter {
|
||||
public:
|
||||
PyTorchFileWriter(const std::string& filename) :
|
||||
out(filename, std::ios_base::binary),
|
||||
stream_writer(out) {}
|
||||
|
||||
uint64_t writeRecord(const char* data, size_t size) {
|
||||
return stream_writer.writeRecord(data, size);
|
||||
}
|
||||
|
||||
void writeEndOfFile() {
|
||||
stream_writer.writeEndOfFile();
|
||||
out.close();
|
||||
}
|
||||
|
||||
private:
|
||||
std::ofstream out;
|
||||
PyTorchStreamWriter stream_writer;
|
||||
};
|
||||
|
||||
}} // namespace torch::jit
|
||||
|
|
|
|||
|
|
@ -20,6 +20,8 @@ import copy
|
|||
import numbers
|
||||
import collections
|
||||
import re
|
||||
if sys.version_info[0] > 2:
|
||||
import pathlib
|
||||
|
||||
|
||||
def _parse_env(name, default, true_message, false_message):
|
||||
|
|
@ -58,19 +60,27 @@ def scope(scope_name):
|
|||
tracing_state.pop_scope()
|
||||
|
||||
|
||||
def load(filename):
|
||||
def load(f):
|
||||
r"""
|
||||
Load a ``ScriptModule`` previously saved with :func:`save <torch.jit.ScriptModule.save>`
|
||||
Load a ``ScriptModule`` previously saved with :func:`save <torch.jit.save>`
|
||||
|
||||
.. DANGER::
|
||||
All previously saved modules, no matter their device, are always loaded onto the CPU.
|
||||
This is different from :func:`torch.load`'s semantics and may change in the future.
|
||||
|
||||
Arguments:
|
||||
filename (string): the file to load
|
||||
f: a file-like object (has to implement read, readline, tell, and seek),
|
||||
or a string containing a file name
|
||||
|
||||
Returns:
|
||||
A ``ScriptModule`` object.
|
||||
|
||||
Example:
|
||||
>>> torch.jit.load('scriptmodule.pt')
|
||||
# Load ScriptModule from io.BytesIO object
|
||||
>>> with open('scriptmodule.pt', 'rb') as f:
|
||||
buffer = io.BytesIO(f.read())
|
||||
>>> torch.jit.load(buffer)
|
||||
"""
|
||||
m = ScriptModule()
|
||||
|
||||
|
|
@ -82,10 +92,48 @@ def load(filename):
|
|||
curr = getattr(curr, name)
|
||||
return curr
|
||||
|
||||
torch._C.import_ir_module(module_lookup, filename)
|
||||
if isinstance(f, str) or \
|
||||
(sys.version_info[0] == 2 and isinstance(f, unicode)) or \
|
||||
(sys.version_info[0] == 3 and isinstance(f, pathlib.Path)):
|
||||
torch._C.import_ir_module(module_lookup, f)
|
||||
else:
|
||||
torch._C.import_ir_module_from_buffer(module_lookup, f.read())
|
||||
return m
|
||||
|
||||
|
||||
def save(m, f):
|
||||
"""
|
||||
Saves a ScriptModule to a file.
|
||||
|
||||
Args:
|
||||
m: a ScriptModule to save
|
||||
f: a file-like object (has to implement write and flush) or a string
|
||||
containing a file name
|
||||
|
||||
.. warning::
|
||||
If you are using Python 2, torch.save does NOT support StringIO.StringIO
|
||||
as a valid file-like object. This is because the write method should return
|
||||
the number of bytes written; StringIO.write() does not do this.
|
||||
|
||||
Please use something like io.BytesIO instead.
|
||||
|
||||
Example:
|
||||
>>> m = torch.jit.ScriptModule()
|
||||
>>> # Save to file
|
||||
>>> torch.jit.save(m, 'scriptmodule.pt')
|
||||
>>> # Save to io.BytesIO buffer
|
||||
>>> buffer = io.BytesIO()
|
||||
>>> torch.jit.save(m, buffer)
|
||||
"""
|
||||
if isinstance(f, str) or \
|
||||
(sys.version_info[0] == 2 and isinstance(f, unicode)) or \
|
||||
(sys.version_info[0] == 3 and isinstance(f, pathlib.Path)):
|
||||
m.save(f)
|
||||
else:
|
||||
ret = m.save_to_buffer()
|
||||
f.write(ret)
|
||||
|
||||
|
||||
def get_trace_graph(f, args=(), kwargs=None):
|
||||
"""
|
||||
Trace a function or model, returning a tuple consisting of the both the
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user