#include #ifdef FLATBUFFERS_VERSION_MAJOR #error "flatbuffer_serializer.h must not include any flatbuffers headers" #endif // FLATBUFFERS_VERSION_MAJOR #include #include #include #include #include #include #include #include #include #include #include #include #include #include #if defined(FB_XPLAT_BUILD) || defined(FBCODE_CAFFE2) #include // NOLINT namespace flatbuffers = flatbuffers_fbsource; #define FLATBUFFERS_MAX_ALIGNMENT FLATBUFFERS_FBSOURCE_MAX_ALIGNMENT #else #include // NOLINT #endif namespace torch::jit { using flatbuffers::FlatBufferBuilder; using mobile::serialization::CreateArg; using mobile::serialization::CreateDebugInfo; using mobile::serialization::CreateDict; using mobile::serialization::CreateFunctionDirect; using mobile::serialization::CreateIValue; using mobile::serialization::CreateList; using mobile::serialization::CreateModule; using mobile::serialization::CreateObject; using mobile::serialization::CreateOperator; using mobile::serialization::CreateTensorMetadataDirect; using mobile::serialization::CreateTupleDirect; namespace { // TODO: remove once caffe2::kProducedBytecodeVersion is >= 9 and flatbuffer is // launched. constexpr uint32_t kMinVersion = 9; // We will store IValue NONE in index 0 in flatbuffer. constexpr int kNoneIndex = 0; static TypePtr realType(TypePtr type) { if (auto dyn = type->castRaw()) { return dyn->fallback(); } else { return type; } } auto print_type(const c10::Type& t) -> std::optional { auto namedType = t.cast(); if (namedType && namedType->name()) { return namedType->name().value().qualifiedName(); } if (auto dyn = t.castRaw()) { return dyn->fallback()->annotation_str(); } return std::nullopt; } class FlatbufferSerializer { public: FlatbufferSerializer() = default; flatbuffers::DetachedBuffer serializeModule( const mobile::Module& module, bool include_tensor_data_in_flatbuffer, const ExtraFilesMap& extra_files = ExtraFilesMap(), const ExtraFilesMap& jit_sources = ExtraFilesMap(), const std::vector& jit_constants = {}); private: template std::vector storeIValuesAndGetIndexes( flatbuffers::FlatBufferBuilder& fbb, It begin, It end) { std::vector indexes; for (; begin != end; ++begin) { indexes.push_back(storeIValueAndGetIndex(fbb, *begin)); } return indexes; } flatbuffers::Offset tupleToFB( flatbuffers::FlatBufferBuilder& fbb, const IValue& tuple); flatbuffers::Offset listToFB( flatbuffers::FlatBufferBuilder& fbb, const IValue& list); flatbuffers::Offset dictToFB( flatbuffers::FlatBufferBuilder& fbb, const IValue& list); flatbuffers::Offset objectToFB( flatbuffers::FlatBufferBuilder& fbb, const IValue& ivalue); flatbuffers::Offset tensorToFB( flatbuffers::FlatBufferBuilder& fbb, const IValue& ivalue); flatbuffers::Offset functionToFB( flatbuffers::FlatBufferBuilder& fbb, const std::string& qn, const mobile::Function& func); flatbuffers::Offset iValueToFB( flatbuffers::FlatBufferBuilder& fbb, const IValue& ivalue); flatbuffers::Offset CreateFBSchema( flatbuffers::FlatBufferBuilder& fbb, const std::vector& args, const std::vector& returns, const c10::TypePrinter& type_printer); flatbuffers::Offset classTypeToFB( flatbuffers::FlatBufferBuilder& fbb, const ClassTypePtr& class_ptr); uint32_t storeIValueAndGetIndex( flatbuffers::FlatBufferBuilder& fbb, const IValue& ivalue); uint32_t storeFunctionAndGetIndex( flatbuffers::FlatBufferBuilder& fbb, const std::string& qn, const mobile::Function& function); uint32_t storeClassTypeAndGetIndex( flatbuffers::FlatBufferBuilder& fbb, const ClassTypePtr& class_type); flatbuffers::Offset>> storeExtraFilesAndGetOffset( FlatBufferBuilder& fbb, const ExtraFilesMap& extra_files); uint32_t insertIValue( flatbuffers::Offset ivalue) { uint32_t size = ivalue_offsets_.size(); ivalue_offsets_.push_back(ivalue); return size; } std::vector tensor_data_; std::unordered_map memoized_storage_map_; std::vector> ivalue_offsets_; std::vector> obj_types_offset_; // qualified name to serialized class, type or function std::unordered_map qn_to_serialized_values_; // cache of some ivalues struct IValueHash { size_t operator()(const IValue& val) const { return IValue::hash(val); } }; struct IValueEqual { // Copy of this // https://www.internalfb.com/code/aros/[3b875bce7ffa2adacdcea9b3e0cb6d304737a193]/xros/third-party/caffe2/caffe2/aten/src/ATen/core/ivalue.cpp?lines=266 // but without relying on aten::nonzero operator being present in the // binary. bool operator()(const IValue& lhs, const IValue& rhs) const { // The only case we don't return bool is for tensor comparison. Lets do // pointer comparison here. if (lhs.isTensor() || rhs.isTensor()) { if (lhs.isTensor() && rhs.isTensor()) { return (&lhs.toTensor()) == (&rhs.toTensor()); } return false; } IValue eq = lhs.equals(rhs); if (eq.isBool()) { return eq.toBool(); } return false; } }; std::unordered_map cached_ivalues_; const mobile::CompilationUnit* mcu_ = nullptr; }; flatbuffers::Offset FlatbufferSerializer:: CreateFBSchema( flatbuffers::FlatBufferBuilder& fbb, const std::vector& args, const std::vector& returns, const c10::TypePrinter& type_printer) { std::vector> arg_vec; arg_vec.reserve(args.size()); std::vector> return_vec; return_vec.reserve(returns.size()); for (const auto& arg : args) { auto index = storeIValueAndGetIndex(fbb, arg.default_value()); arg_vec.emplace_back(CreateArg( fbb, fbb.CreateSharedString(arg.name()), fbb.CreateSharedString( realType(arg.type())->annotation_str(type_printer)), index)); } for (const auto& ret : returns) { auto index = storeIValueAndGetIndex(fbb, ret.default_value()); return_vec.emplace_back(CreateArg( fbb, fbb.CreateSharedString(ret.name()), fbb.CreateSharedString( realType(ret.type())->annotation_str(type_printer)), index)); } return CreateSchema( fbb, fbb.CreateVector(arg_vec), fbb.CreateVector(return_vec)); } flatbuffers::Offset FlatbufferSerializer:: functionToFB( FlatBufferBuilder& fbb, const std::string& qn, const mobile::Function& func) { const auto& code = func.get_code(); // instructions std::vector instruction_vector; instruction_vector.reserve(code.instructions_.size()); for (const auto& inst : code.instructions_) { instruction_vector.emplace_back(inst.op, inst.N, inst.X); } // operators std::vector> operator_vector; operator_vector.reserve(code.op_names_.size()); for (const auto i : c10::irange(code.op_names_.size())) { const auto& opname = code.op_names_[i]; const int op_size = code.operator_input_sizes_[i]; operator_vector.push_back(CreateOperator( fbb, fbb.CreateSharedString(opname.name), fbb.CreateSharedString(opname.overload_name), op_size)); } const auto& constants = code.constants_; std::vector constant_indexes; constant_indexes.reserve(constants.size()); for (const auto& constant : constants) { constant_indexes.push_back(storeIValueAndGetIndex(fbb, constant)); } // types static const std::string torch_prefix("__torch__"); static const std::string class_prefix("__torch__.torch.classes"); std::vector> type_offsets; for (const TypePtr& t : code.types_) { auto type_str = realType(t)->annotation_str(); if (type_str.find(torch_prefix) == 0) { TORCH_CHECK( type_str.find(class_prefix) == 0, "__torch__ types other than custom c++ classes (__torch__.torch.classes)" "are not supported in lite interpreter. ", "Workaround: instead of using arbitrary class type (class Foo()), ", "define a pytorch class (class Foo(torch.nn.Module))."); } type_offsets.push_back(fbb.CreateSharedString(type_str)); } // since the register location is embedded into the bytecode, pass the // register size auto register_size = static_cast(code.register_size_); // schema auto type_printer = [&](const c10::Type& t) -> std::optional { auto namedType = t.cast(); if (namedType && namedType->name()) { return namedType->name().value().qualifiedName(); } if (auto dyn = t.castRaw()) { return dyn->fallback()->annotation_str(); } return std::nullopt; }; flatbuffers::Offset schema_offset = 0; uint32_t class_index = 0; if (func.hasSchema()) { const auto& schema = func.getSchema(); TORCH_CHECK( schema.overload_name().empty(), // @TODO: is this check correct? "Overloads are not supported in mobile modules."); TORCH_CHECK( !schema.is_vararg(), "Python *args are not supported in mobile modules."); TORCH_CHECK( !schema.is_varret(), "A variable number of return values is not supported in mobile modules."); schema_offset = CreateFBSchema(fbb, schema.arguments(), schema.returns(), type_printer); auto classtype = schema.arguments()[0].type()->cast(); class_index = storeClassTypeAndGetIndex(fbb, classtype); } auto debug_info_offset = CreateDebugInfo(fbb, fbb.CreateVector(code.debug_handles_)); auto function_offset = CreateFunctionDirect( fbb, qn.c_str(), &instruction_vector, &operator_vector, &constant_indexes, &type_offsets, register_size, schema_offset, debug_info_offset, class_index); return function_offset; } flatbuffers::Offset< flatbuffers::Vector>> FlatbufferSerializer::storeExtraFilesAndGetOffset( FlatBufferBuilder& fbb, const ExtraFilesMap& extra_files) { std::vector> extra_file_offsets; for (const auto& extra_file : extra_files) { flatbuffers::Offset extra_file_offset = mobile::serialization::CreateExtraFile( fbb, fbb.CreateSharedString(extra_file.first), fbb.CreateString(extra_file.second)); extra_file_offsets.emplace_back(extra_file_offset); } return fbb.CreateVector(extra_file_offsets); } flatbuffers::DetachedBuffer FlatbufferSerializer::serializeModule( const mobile::Module& module, bool include_tensor_data_in_flatbuffer, const ExtraFilesMap& extra_files, const ExtraFilesMap& jit_sources, const std::vector& jit_constants) { FlatBufferBuilder fbb; mcu_ = &module.compilation_unit(); // first element is None. insertIValue(CreateIValue(fbb, mobile::serialization::IValueUnion::NONE, 0)); auto methods = module.get_methods(); std::vector functions_index; functions_index.reserve(methods.size()); for (const auto& method : methods) { auto func_offset = storeFunctionAndGetIndex( fbb, method.function().qualname().qualifiedName(), method.function()); functions_index.push_back(func_offset); } auto functions_offset = fbb.CreateVector(functions_index); uint32_t ivalue_index = storeIValueAndGetIndex(fbb, module._ivalue()); flatbuffers::Offset>> storage_data_offset = 0; auto extra_files_offset = storeExtraFilesAndGetOffset(fbb, extra_files); auto jit_source_offset = storeExtraFilesAndGetOffset(fbb, jit_sources); std::vector jit_constants_indexes; jit_constants_indexes.reserve(jit_constants.size()); const uint32_t mobile_ivalue_size = ivalue_offsets_.size(); for (const auto& ival : jit_constants) { jit_constants_indexes.emplace_back(storeIValueAndGetIndex(fbb, ival)); } const uint32_t operator_version = static_cast(module.min_operator_version()); uint32_t bytecode_version = static_cast(module.bytecode_version()); if (bytecode_version < kMinVersion) { bytecode_version = kMinVersion; } // NOTE: saving of storage has to be the last thing to do. if (include_tensor_data_in_flatbuffer) { std::vector> storage_data; for (auto td : tensor_data_) { if (td.storage().device_type() != DeviceType::CPU) { td = at::empty({0}, td.options()) .set_( td.storage(), /* storage_offset = */ 0, /* size = */ {static_cast( td.storage().nbytes() / td.element_size())}, /* stride = */ {1}) .cpu(); } fbb.ForceVectorAlignment( td.storage().nbytes(), sizeof(uint8_t), FLATBUFFERS_MAX_ALIGNMENT); auto storage_offset = mobile::serialization::CreateStorageData( fbb, fbb.CreateVector( reinterpret_cast(td.storage().data()), td.storage().nbytes())); storage_data.push_back(storage_offset); } storage_data_offset = fbb.CreateVector(storage_data); } auto mod = CreateModule( fbb, /*bytecode_version=*/bytecode_version, extra_files_offset, /* extra_files */ functions_offset, ivalue_index, fbb.CreateVector(ivalue_offsets_), static_cast(tensor_data_.size()), storage_data_offset, fbb.CreateVector(obj_types_offset_), jit_source_offset, fbb.CreateVector(jit_constants_indexes), operator_version, mobile_ivalue_size); FinishModuleBuffer(fbb, mod); return fbb.Release(); } flatbuffers::Offset FlatbufferSerializer:: tupleToFB(flatbuffers::FlatBufferBuilder& fbb, const IValue& tuple) { const auto& elements = tuple.toTuple()->elements(); std::vector items = storeIValuesAndGetIndexes(fbb, elements.begin(), elements.end()); return CreateTupleDirect(fbb, &items); } flatbuffers::Offset FlatbufferSerializer::listToFB( flatbuffers::FlatBufferBuilder& fbb, const IValue& list) { const auto& elements = list.toList(); std::vector items = storeIValuesAndGetIndexes(fbb, elements.begin(), elements.end()); return CreateList( fbb, fbb.CreateVector(items), fbb.CreateSharedString( realType(list.type())->annotation_str(print_type))); } flatbuffers::Offset FlatbufferSerializer::dictToFB( flatbuffers::FlatBufferBuilder& fbb, const IValue& ivalue) { const auto& dict = ivalue.toGenericDict(); std::vector keys; std::vector values; keys.reserve(dict.size()); values.reserve(dict.size()); for (const auto& entry : dict) { auto key_index = storeIValueAndGetIndex(fbb, entry.key()); keys.push_back(key_index); auto value_index = storeIValueAndGetIndex(fbb, entry.value()); values.push_back(value_index); } return CreateDict( fbb, fbb.CreateVector(keys), fbb.CreateVector(values), fbb.CreateSharedString( realType(ivalue.type())->annotation_str(print_type))); } flatbuffers::Offset FlatbufferSerializer:: classTypeToFB(FlatBufferBuilder& fbb, const ClassTypePtr& class_ptr) { mobile::serialization::TypeType typetype = mobile::serialization::TypeType::UNSET; flatbuffers::Offset< flatbuffers::Vector>> names_offset = 0; c10::QualifiedName setstate_name(*class_ptr->name(), "__setstate__"); c10::QualifiedName getstate_name(*class_ptr->name(), "__getstate__"); const mobile::Function* setstate = mcu_->find_function(setstate_name); const mobile::Function* getstate = mcu_->find_function(getstate_name); if (setstate != nullptr && getstate != nullptr) { typetype = mobile::serialization::TypeType::CLASS_WITH_SETSTATE; } else if ( class_ptr->findMethod("__setstate__") && class_ptr->findMethod("__getstate__")) { typetype = mobile::serialization::TypeType::CUSTOM_CLASS; } else { size_t num_attr = class_ptr->numAttributes(); std::vector> names; names.reserve(num_attr); for (size_t i = 0; i < num_attr; ++i) { names.push_back(fbb.CreateSharedString(class_ptr->getAttributeName(i))); } names_offset = fbb.CreateVector(names); typetype = mobile::serialization::TypeType::CLASS_WITH_FIELD; } auto name_offset = fbb.CreateString(class_ptr->name()->qualifiedName()); return CreateObjectType(fbb, name_offset, typetype, names_offset); } uint32_t FlatbufferSerializer::storeFunctionAndGetIndex( flatbuffers::FlatBufferBuilder& fbb, const std::string& qn, const mobile::Function& function) { auto iter = qn_to_serialized_values_.find(qn); if (iter != qn_to_serialized_values_.end()) { return iter->second; } auto offset = CreateIValue( fbb, mobile::serialization::IValueUnion::Function, functionToFB(fbb, qn, function).Union()); uint32_t index = insertIValue(offset); qn_to_serialized_values_[qn] = index; return index; } uint32_t FlatbufferSerializer::storeClassTypeAndGetIndex( FlatBufferBuilder& fbb, const ClassTypePtr& class_ptr) { const auto& type_str = class_ptr->name()->qualifiedName(); auto iter = qn_to_serialized_values_.find(type_str); if (iter != qn_to_serialized_values_.end()) { return iter->second; } auto offset = classTypeToFB(fbb, class_ptr); uint32_t res = obj_types_offset_.size(); obj_types_offset_.push_back(offset); qn_to_serialized_values_[type_str] = res; return res; } flatbuffers::Offset FlatbufferSerializer:: objectToFB(flatbuffers::FlatBufferBuilder& fbb, const IValue& ivalue) { auto obj = ivalue.toObject(); auto type = obj->type(); // rename type? // check getstate // save state as ivalue flatbuffers::Offset> attrs = 0; uint32_t state_index = 0; uint32_t setstate_func_index = 0; const auto qn = type->name()->qualifiedName() + ".__setstate__"; auto getstate = type->findMethod("__getstate__"); auto setstate = type->findMethod("__setstate__"); if (getstate && setstate) { auto state = (*getstate)({obj}); state_index = storeIValueAndGetIndex(fbb, state); auto func_index = qn_to_serialized_values_.find(qn); if (func_index != qn_to_serialized_values_.end()) { setstate_func_index = func_index->second; } } else { size_t num_attr = type->numAttributes(); std::vector tuple_index; tuple_index.reserve(num_attr); for (size_t i = 0; i < num_attr; ++i) { tuple_index.push_back(storeIValueAndGetIndex(fbb, obj->getSlot(i))); } attrs = fbb.CreateVector(tuple_index); } uint32_t type_index = storeClassTypeAndGetIndex(fbb, type); return CreateObject(fbb, type_index, state_index, attrs, setstate_func_index); } flatbuffers::Offset FlatbufferSerializer:: FlatbufferSerializer::tensorToFB( flatbuffers::FlatBufferBuilder& fbb, const IValue& ivalue) { auto& tensor = ivalue.toTensor(); bool quantized = tensor.is_quantized(); const at::Storage& storage = tensor.storage(); flatbuffers::Offset qschema_offset = 0; if (quantized) { double scale = 0; int64_t zero_point = 0; flatbuffers::Offset scales = 0; flatbuffers::Offset zero_points = 0; int64_t axis = 0; switch (tensor.qscheme()) { case at::kPerTensorAffine: scale = tensor.q_scale(); zero_point = tensor.q_zero_point(); break; case at::kPerChannelAffineFloatQParams: case at::kPerChannelAffine: { scales = tensorToFB(fbb, tensor.q_per_channel_scales()); zero_points = tensorToFB(fbb, tensor.q_per_channel_zero_points()); axis = tensor.q_per_channel_axis(); } break; default: TORCH_CHECK( false, "Unsupported tensor quantization type in serialization ", toString(tensor.qscheme())); break; } qschema_offset = mobile::serialization::CreateQuantizedSchema( fbb, static_cast(tensor.qscheme()), scale, static_cast(zero_point), scales, zero_points, static_cast(axis)); } void* addr = storage.unsafeGetStorageImpl(); uint32_t storage_index = 0; auto it = memoized_storage_map_.find(addr); if (it != memoized_storage_map_.end()) { storage_index = it->second; } else { storage_index = tensor_data_.size(); memoized_storage_map_[addr] = storage_index; tensor_data_.push_back(tensor); } std::vector sizes{tensor.sizes().begin(), tensor.sizes().end()}; std::vector strides{tensor.strides().begin(), tensor.strides().end()}; return CreateTensorMetadataDirect( fbb, /* storage_location_index */ storage_index, /* scalar_type */ static_cast(tensor.scalar_type()), /* int32_t storage_offset */ static_cast(tensor.storage_offset()), /* sizes */ &sizes, /* strides */ &strides, /* bool requires_grad */ tensor.requires_grad(), /* qschema */ qschema_offset); } uint32_t FlatbufferSerializer::storeIValueAndGetIndex( flatbuffers::FlatBufferBuilder& fbb, const IValue& ivalue) { if (ivalue.isNone()) { return kNoneIndex; } try { auto iter = cached_ivalues_.find(ivalue); if (iter != cached_ivalues_.end()) { return iter->second; } // NOLINTNEXTLINE(bugprone-empty-catch) } catch (...) { // Threw if ivalue is not hashable or // if ivalue is don't have proper operator== // we don't care catchall because either case we want to skip hashing } auto offset = iValueToFB(fbb, ivalue); uint32_t index = insertIValue(offset); try { cached_ivalues_[ivalue] = index; // NOLINTNEXTLINE(bugprone-empty-catch) } catch (...) { // Threw if ivalue is not hashable or // if ivalue is don't have proper operator== // we don't care catchall because either case we want to skip hashing } return index; } flatbuffers::Offset FlatbufferSerializer:: iValueToFB(flatbuffers::FlatBufferBuilder& fbb, const IValue& ivalue) { using mobile::serialization::IValueUnion; IValueUnion ivalue_type = IValueUnion::NONE; flatbuffers::Offset offset = 0; if (ivalue.isTensor()) { ivalue_type = IValueUnion::TensorMetadata; offset = tensorToFB(fbb, ivalue).Union(); } else if (ivalue.isTuple()) { ivalue_type = IValueUnion::Tuple; offset = tupleToFB(fbb, ivalue).Union(); } else if (ivalue.isDouble()) { ivalue_type = IValueUnion::Double; offset = fbb.CreateStruct(mobile::serialization::Double(ivalue.toDouble())) .Union(); } else if (ivalue.isComplexDouble()) { auto comp = ivalue.toComplexDouble(); ivalue_type = IValueUnion::ComplexDouble; offset = fbb.CreateStruct(mobile::serialization::ComplexDouble( comp.real(), comp.imag())) .Union(); } else if (ivalue.isInt()) { ivalue_type = IValueUnion::Int; offset = fbb.CreateStruct(mobile::serialization::Int(ivalue.toInt())).Union(); } else if (ivalue.isBool()) { ivalue_type = IValueUnion::Bool; offset = fbb.CreateStruct(mobile::serialization::Bool(ivalue.toBool())).Union(); } else if (ivalue.isString()) { ivalue_type = IValueUnion::String; offset = mobile::serialization::CreateString( fbb, fbb.CreateSharedString(ivalue.toStringRef())) .Union(); } else if (ivalue.isGenericDict()) { ivalue_type = IValueUnion::Dict; offset = dictToFB(fbb, ivalue).Union(); } else if (ivalue.isNone()) { ivalue_type = IValueUnion::NONE; offset = 0; } else if (ivalue.isIntList()) { ivalue_type = IValueUnion::IntList; offset = mobile::serialization::CreateIntList( fbb, fbb.CreateVector(ivalue.toIntVector())) .Union(); } else if (ivalue.isDoubleList()) { ivalue_type = IValueUnion::DoubleList; offset = mobile::serialization::CreateDoubleList( fbb, fbb.CreateVector(ivalue.toDoubleVector())) .Union(); } else if (ivalue.isBoolList()) { ivalue_type = IValueUnion::BoolList; auto boollist = ivalue.toBoolList(); std::vector bool_vec(boollist.begin(), boollist.end()); offset = mobile::serialization::CreateBoolListDirect(fbb, &bool_vec).Union(); } else if (ivalue.isList()) { ivalue_type = IValueUnion::List; offset = listToFB(fbb, ivalue).Union(); } else if (ivalue.isObject()) { ivalue_type = IValueUnion::Object; offset = objectToFB(fbb, ivalue).Union(); } else if (ivalue.isDevice()) { ivalue_type = IValueUnion::Device; offset = mobile::serialization::CreateDevice( fbb, fbb.CreateSharedString(ivalue.toDevice().str())) .Union(); } else if (ivalue.isEnum()) { const auto& enum_holder = ivalue.toEnumHolder(); const auto& qualified_class_name = enum_holder->type()->qualifiedClassName(); uint32_t ival_pos = storeIValueAndGetIndex(fbb, enum_holder->value()); ivalue_type = IValueUnion::EnumValue; offset = mobile::serialization::CreateEnumValue( fbb, fbb.CreateSharedString(qualified_class_name.qualifiedName()), ival_pos) .Union(); } else { TORCH_CHECK( false, "Invalid IValue type for serialization: ", ivalue.tagKind()); } return CreateIValue(fbb, ivalue_type, offset); } } // namespace void save_mobile_module( const mobile::Module& module, const std::string& filename, const ExtraFilesMap& extra_files, const ExtraFilesMap& jit_sources, const std::vector& jit_constants) { auto buffer = save_mobile_module_to_bytes( module, extra_files, jit_sources, jit_constants); std::fstream ofile(filename, std::ios::binary | std::ios::out); ofile.write( reinterpret_cast(buffer->data()), static_cast(buffer->size())); ofile.close(); } /// Deletes a DetachedBuffer, along with the internal /// flatbuffers::DetachedBuffer if present. Used as a custom deleter for /// std::unique_ptr; see UniqueDetachedBuffer and make_unique_detached_buffer. void DetachedBuffer::destroy(DetachedBuffer* buf) { // May be null. delete static_cast(buf->data_owner_); delete buf; } /// Provides access to DetachedBuffer::destroy(). struct DetachedBufferFriend { /// Returns a UniqueDetachedBuffer that wraps the provided DetachedBuffer. static DetachedBuffer::UniqueDetachedBuffer make_unique_detached_buffer( DetachedBuffer* buf) { return DetachedBuffer::UniqueDetachedBuffer(buf, DetachedBuffer::destroy); } }; DetachedBuffer::UniqueDetachedBuffer save_mobile_module_to_bytes( const mobile::Module& module, const ExtraFilesMap& extra_files, const ExtraFilesMap& jit_sources, const std::vector& jit_constants) { FlatbufferSerializer fb_serializer; flatbuffers::DetachedBuffer buf = fb_serializer.serializeModule( module, /*include_tensor_data_in_flatbuffer=*/true, extra_files, jit_sources, jit_constants); flatbuffers::DetachedBuffer* buf_ptr = new flatbuffers::DetachedBuffer(std::move(buf)); DetachedBuffer* ret = new DetachedBuffer(buf_ptr->data(), buf_ptr->size(), buf_ptr); return DetachedBufferFriend::make_unique_detached_buffer(ret); } void save_mobile_module_to_func( const mobile::Module& module, const std::function& writer_func) { auto buffer = save_mobile_module_to_bytes(module); writer_func(buffer->data(), buffer->size()); } bool register_flatbuffer_serializer() { return true; } } // namespace torch::jit