// NOTE: header paths reconstructed from the identifiers used in this file.
#include <torch/csrc/jit/mobile/flatbuffer_loader.h>

#include <array>
#include <istream>
#include <memory>
#include <string>
#include <tuple>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include <flatbuffers/flatbuffers.h>

#include <ATen/ATen.h>
#include <ATen/core/dynamic_type.h>
#include <ATen/core/ivalue.h>
#include <ATen/core/qualified_name.h>
#include <c10/core/CPUAllocator.h>
#include <c10/util/Exception.h>
#include <c10/util/Optional.h>
#include <caffe2/serialize/inline_container.h>
#include <torch/csrc/jit/mobile/file_format.h>
#include <torch/csrc/jit/mobile/function.h>
#include <torch/csrc/jit/mobile/import.h>
#include <torch/csrc/jit/mobile/interpreter.h>
#include <torch/csrc/jit/mobile/module.h>
#include <torch/csrc/jit/mobile/observer.h>
#include <torch/csrc/jit/mobile/type_parser.h>
#include <torch/csrc/jit/runtime/instruction.h>
#include <torch/csrc/jit/serialization/import_export_constants.h>
#include <torch/csrc/jit/serialization/import_read.h>
#include <torch/custom_class.h>

#ifndef DISABLE_UPGRADER
#include <torch/csrc/jit/mobile/parse_bytecode.h>
#include <torch/csrc/jit/mobile/upgrader_mobile.h>
#endif

#if defined(HAVE_MMAP)
#include <fcntl.h>
#include <sys/mman.h>
#include <sys/stat.h>
#include <sys/types.h>
#endif

#ifdef _WIN32
#include <malloc.h>
#else
#include <cstdlib>
#endif

#include <cstring>
#include <memory>

namespace torch {
namespace jit {

using caffe2::serialize::IStreamAdapter;
using caffe2::serialize::PyTorchStreamReader;
using caffe2::serialize::ReadAdapterInterface;

static constexpr c10::string_view kCustomClassPrefix =
    "__torch__.torch.classes";
static constexpr c10::string_view kTorchPrefix = "__torch__";
static constexpr c10::string_view kJitPrefix = "torch.jit";

template <typename T, typename U>
std::vector<T> parseListNative(const U* list) {
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(list != nullptr);
  return {list->items()->begin(), list->items()->end()};
}

IValue parseList(
    FlatbufferLoader&,
    const mobile::serialization::IValue& ivalue);
IValue parseTensor(
    FlatbufferLoader&,
    const mobile::serialization::IValue& ivalue);
IValue parseTuple(
    FlatbufferLoader&,
    const mobile::serialization::IValue& ivalue);
IValue parseDict(
    FlatbufferLoader&,
    const mobile::serialization::IValue& ivalue);
IValue parseObject(
    FlatbufferLoader&,
    const mobile::serialization::IValue& ivalue);
IValue parseIntList(
    FlatbufferLoader&,
    const mobile::serialization::IValue& ivalue);
IValue parseDoubleList(
    FlatbufferLoader&,
    const mobile::serialization::IValue& ivalue);
IValue parseBoolList(
    FlatbufferLoader&,
    const mobile::serialization::IValue& ivalue);
IValue parseBasic(
    FlatbufferLoader&,
    const mobile::serialization::IValue& ivalue);
IValue parseEnum(
    FlatbufferLoader&,
    const mobile::serialization::IValue& ivalue);

TypePtr resolveType(
    const std::string& type_string,
    std::shared_ptr<CompilationUnit> cu) {
  TypePtr type;
  c10::string_view type_str(type_string);
  if (type_str.starts_with(kCustomClassPrefix)) {
    type = getCustomClass(type_string);
    TORCH_CHECK(
        type,
        "The implementation of class ",
        type_string,
        " cannot be found.");
  } else if (
      type_str.starts_with(kTorchPrefix) || type_str.starts_with(kJitPrefix)) {
    c10::QualifiedName qn(type_string);
    if (cu->get_class(qn) == nullptr) {
      auto classtype = ClassType::create(qn, cu, true);
      cu->register_type(classtype);
      type = classtype;
    } else {
      type = cu->get_class(qn);
    }
  } else {
    type = c10::parseType(type_string);
  }
  return type;
}

FlatbufferLoader::FlatbufferLoader()
    : mcu_(std::make_shared<mobile::CompilationUnit>()),
      cu_(std::make_shared<CompilationUnit>()),
      ivalue_parsers_{nullptr} {
  registerIValueParser(mobile::serialization::IValueUnion::NONE, &parseBasic);
  registerIValueParser(mobile::serialization::IValueUnion::Int, &parseBasic);
  registerIValueParser(mobile::serialization::IValueUnion::Bool, &parseBasic);
  registerIValueParser(mobile::serialization::IValueUnion::Double, &parseBasic);
  registerIValueParser(
      mobile::serialization::IValueUnion::ComplexDouble, &parseBasic);
  registerIValueParser(
      mobile::serialization::IValueUnion::TensorMetadata, &parseTensor);
  registerIValueParser(mobile::serialization::IValueUnion::String, &parseBasic);
  registerIValueParser(mobile::serialization::IValueUnion::List, &parseList);
  registerIValueParser(
      mobile::serialization::IValueUnion::IntList, &parseIntList);
  registerIValueParser(
      mobile::serialization::IValueUnion::DoubleList, &parseDoubleList);
  registerIValueParser(
      mobile::serialization::IValueUnion::BoolList, &parseBoolList);
  registerIValueParser(mobile::serialization::IValueUnion::Tuple, &parseTuple);
  registerIValueParser(mobile::serialization::IValueUnion::Dict, &parseDict);
  registerIValueParser(
      mobile::serialization::IValueUnion::Object, &parseObject);
  registerIValueParser(mobile::serialization::IValueUnion::Device, &parseBasic);
  registerIValueParser(
      mobile::serialization::IValueUnion::EnumValue, &parseEnum);
  internal_registerTypeResolver(&resolveType);
}

void FlatbufferLoader::registerIValueParser(
    mobile::serialization::IValueUnion ivalue_type,
    IValueParser parser) {
  ivalue_parsers_[static_cast<uint32_t>(ivalue_type)] = parser;
}

void FlatbufferLoader::internal_registerTypeResolver(
    TypeResolver type_resolver) {
  type_resolver_ = type_resolver;
}
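
// Extra files are serialized as a flat vector of (name, content) string
// pairs; the helpers below copy them into the caller-supplied ExtraFilesMap.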
void parseExtraFilesFromVector(
    const flatbuffers::Vector<
        flatbuffers::Offset<mobile::serialization::ExtraFile>>* files,
    ExtraFilesMap* extra_files) {
  for (uint32_t i = 0; i < files->size(); ++i) {
    const auto* extra_file = files->Get(i);
    (*extra_files)[extra_file->name()->str()] = extra_file->content()->str();
  }
}

void parseExtraFiles(
    mobile::serialization::Module* module,
    ExtraFilesMap& extra_files) {
  auto extra_files_offsets = module->extra_files();
  parseExtraFilesFromVector(extra_files_offsets, &extra_files);
}

void FlatbufferLoader::parseAndPopulate(
    uint32_t i,
    const mobile::serialization::IValue* ivalue) {
  if (const auto* func = ivalue->val_as_Function()) {
    auto func_ptr = parseFunction(func);
    all_functions_[i] = func_ptr.get();
    mcu_->register_function(std::move(func_ptr));
  } else {
    all_ivalues_[i] = parseIValue(ivalue);
  }
}

mobile::Module FlatbufferLoader::parseModule(
    mobile::serialization::Module* module) {
  module_ = module;
  all_ivalues_.clear();
  all_types_.clear();
  storages_.clear();
  storage_loaded_.clear();
  module_parsed_ = false;

  const auto* ivalues = module->ivalues();
  all_ivalues_.resize(ivalues->size());
  all_types_.resize(module->object_types()->size());
  storages_.resize(module->storage_data_size());
  storage_loaded_.resize(module->storage_data_size(), false);

  mobile_ivalue_size_ = module_->mobile_ivalue_size();
  if (mobile_ivalue_size_ == 0) {
    mobile_ivalue_size_ = ivalues->size();
  }

  for (uint32_t i = 0; i < mobile_ivalue_size_; i++) {
    const auto* ival = ivalues->Get(i);
    parseAndPopulate(i, ival);
  }
  IValue& module_ivalue = getIValue(module->state_obj());

  // register functions
  for (const auto& f : all_functions_) {
    uint32_t class_index =
        ivalues->Get(f.first)->val_as_Function()->class_type();
    ClassTypePtr class_type = all_types_[class_index];
    class_type->addMethod(f.second);
  }

  module_parsed_ = true;
  auto m = mobile::Module(module_ivalue.toObject(), mcu_);
  m.set_min_operator_version(module->operator_version());
  m.set_bytecode_version(module->bytecode_version());
  return m;
}

namespace {
void appendUpgraderFunctions(mobile::Function* function) {
#ifndef DISABLE_UPGRADER
  for (auto& byteCodeFunctionWithOperator : getUpgraderBytecodeList()) {
    function->append_function(byteCodeFunctionWithOperator.function);
  }
#endif
}
} // namespace

std::unique_ptr<mobile::Function> FlatbufferLoader::parseFunction(
    const mobile::serialization::Function* method) {
  auto function = std::make_unique<mobile::Function>(
      c10::QualifiedName(method->qn()->str()));
  // TODO(qihan) add debug handle
  // const auto* debug_handle = method->debug_info()->debug_handle();
  for (const auto* inst : *method->instructions()) {
    function->append_instruction(
        static_cast<OpCode>(inst->op()), inst->x(), inst->n());
  }

  for (uint32_t i : *method->constants()) {
    function->append_constant(getIValue(i));
  }

  // 1. Append the upgrader functions so they can be called if needed.
  appendUpgraderFunctions(function.get());

  // 2. Decide whether an upgrader is needed.
  const uint32_t operator_version = module_->operator_version();
  bool use_upgrader =
      (operator_version < caffe2::serialize::kProducedFileFormatVersion);

  for (const auto* op : *method->operators()) {
    c10::optional<int> num_args = c10::nullopt;
    if (op->num_args_serialized() > -1) {
      num_args = op->num_args_serialized();
    }

    function->append_operator(
        op->name()->str(), op->overload_name()->str(), num_args);
  }

  if (should_load_operators_) {
    function->initialize_operators(true);
  }

  for (const auto i : *method->type_annotations()) {
    function->append_type(getOrCreateTypeAnnotations(i));
  }

  // 3. If an upgrader is needed, change the OP instructions to CALL
  // instructions. (In a follow-up PR, use_upgrader will be passed to the
  // parseInstruction function, which will perform the actual change.)
  if (use_upgrader) {
#ifndef DISABLE_UPGRADER
    applyUpgrader(function.get(), operator_version);
#endif
  }

  function->set_register_size(method->register_size());
  if (method->schema()) {
    try {
      auto parseArgList = [this](const auto* args_fb) {
        std::vector<c10::Argument> args;
        for (const auto* arg_tb : *args_fb) {
          IValue default_value = getIValue(arg_tb->default_value());
          TypePtr type_ptr = getOrCreateTypeAnnotations(arg_tb->type());
          auto arg = c10::Argument(
              arg_tb->name()->str(),
              std::move(type_ptr),
              c10::nullopt /*N*/,
              std::move(default_value));
          args.emplace_back(std::move(arg));
        }
        return args;
      };
      c10::FunctionSchema schema(
          method->qn()->str(),
          "" /*overload_name*/,
          parseArgList(method->schema()->arguments()),
          parseArgList(method->schema()->returns()),
          false /*is_varargs*/,
          false /*is_varret*/);

      function->setSchema(std::move(schema));
    } catch (const c10::Error& e) {
      // Schema parsing is best-effort; a function without a schema is still
      // runnable.
    }
  }
  return function;
}
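
// Per-type IValue parsers. Each parser below is registered against its
// IValueUnion tag in the FlatbufferLoader constructor, so parseIValue() can
// dispatch on the union type with a single table lookup.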
IValue parseEnum(
    FlatbufferLoader& loader,
    const mobile::serialization::IValue& ivalue) {
  const auto* enum_val = ivalue.val_as_EnumValue();
  auto enum_type = loader.getOrCreateTypeAnnotations(enum_val->type_name())
                       ->cast<c10::EnumType>();
  AT_ASSERT(
      enum_type,
      "Enum with type: " + enum_val->type_name()->str() + " not found.");
  IValue val = loader.getIValue(enum_val->value());
  for (const auto& p : enum_type->enumNamesValues()) {
    if (p.second == val) {
      auto enum_holder = c10::make_intrusive<at::ivalue::EnumHolder>(
          enum_type, p.first, p.second);
      return IValue(std::move(enum_holder));
    }
  }
  AT_ASSERT(
      false,
      "Enum with type: " + enum_val->type_name()->str() + " not found.");
}

IValue parseBasic(
    FlatbufferLoader&,
    const mobile::serialization::IValue& ivalue) {
  switch (ivalue.val_type()) {
    case mobile::serialization::IValueUnion::NONE:
      return {};
    case mobile::serialization::IValueUnion::Int:
      return ivalue.val_as_Int()->int_val();
    case mobile::serialization::IValueUnion::Bool:
      return ivalue.val_as_Bool()->bool_val();
    case mobile::serialization::IValueUnion::Double:
      return ivalue.val_as_Double()->double_val();
    case mobile::serialization::IValueUnion::ComplexDouble: {
      const auto* comp = ivalue.val_as_ComplexDouble();
      return c10::complex<double>(comp->real(), comp->imag());
    }
    case mobile::serialization::IValueUnion::String:
      return ivalue.val_as_String()->data()->str();
    case mobile::serialization::IValueUnion::Device: {
      return c10::Device(ivalue.val_as_Device()->str()->str());
    }
    default:
      return {};
  }
}

at::Tensor parseTensorFromMetadata(
    FlatbufferLoader* loader,
    const mobile::serialization::TensorMetadata* tensor_md) {
  at::ScalarType type = static_cast<at::ScalarType>(tensor_md->scalar_type());
  auto options = at::CPU(type).options();
  at::Tensor tensor;
  if (tensor_md->quantized_schema() != nullptr) {
    // Tensor is quantized.
    const auto* schema = tensor_md->quantized_schema();
    auto qscheme_type = static_cast<at::QScheme>(schema->qscheme());
    switch (qscheme_type) {
      case at::kPerTensorAffine: {
        tensor = at::_empty_affine_quantized(
            {0}, options, schema->scale(), schema->zero_point());
      } break;
      case at::kPerChannelAffineFloatQParams:
      case at::kPerChannelAffine: {
        at::Tensor scales = parseTensorFromMetadata(loader, schema->scales());
        at::Tensor zero_points =
            parseTensorFromMetadata(loader, schema->zero_points());
        tensor = at::_empty_per_channel_affine_quantized(
            {0}, scales, zero_points, schema->axis(), options);
      } break;
      default:
        TORCH_CHECK(
            false,
            "Unsupported tensor quantization type in serialization ",
            toString(qscheme_type));
        break;
    }
  } else {
    tensor = at::empty({0}, options);
  }
  at::TensorImpl* impl = tensor.unsafeGetTensorImpl();

  c10::Storage storage;
  storage = loader->getStorage(tensor_md->storage_location_index());
  impl->set_storage_keep_dtype(storage);
  impl->set_storage_offset(tensor_md->storage_offset());

  std::vector<int64_t> size{
      tensor_md->sizes()->begin(), tensor_md->sizes()->end()};
  std::vector<int64_t> stride{
      tensor_md->strides()->begin(), tensor_md->strides()->end()};
  impl->set_sizes_and_strides(size, stride);
#ifndef MIN_EDGE_RUNTIME
  tensor = autograd::make_variable(tensor, tensor_md->requires_grad());
#endif
  return tensor;
}

IValue parseTensor(
    FlatbufferLoader& loader,
    const mobile::serialization::IValue& ivalue) {
  const mobile::serialization::TensorMetadata* tensor_md =
      ivalue.val_as_TensorMetadata();
  return parseTensorFromMetadata(&loader, tensor_md);
}
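
// Container parsers. Generic lists and dicts are first built with Any
// element types, then narrowed to the element/key/value types recovered from
// the serialized annotation string.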
IValue parseList(
    FlatbufferLoader& loader,
    const mobile::serialization::IValue& ivalue) {
  const mobile::serialization::List* list = ivalue.val_as_List();
  auto res = c10::impl::GenericList(AnyType::get());
  for (int i : *list->items()) {
    res.emplace_back(loader.getIValue(i));
  }
  auto type = loader.getOrCreateTypeAnnotations(list->annotation_str());
  res.unsafeSetElementType(type->containedType(0));
  return res;
}

IValue parseIntList(
    FlatbufferLoader&,
    const mobile::serialization::IValue& ivalue) {
  const auto& list = ivalue.val_as_IntList();
  return parseListNative<int64_t>(list);
}

IValue parseDoubleList(
    FlatbufferLoader&,
    const mobile::serialization::IValue& ivalue) {
  const auto& list = ivalue.val_as_DoubleList();
  return parseListNative<double>(list);
}

IValue parseBoolList(
    FlatbufferLoader&,
    const mobile::serialization::IValue& ivalue) {
  const auto& list = ivalue.val_as_BoolList();
  std::vector<uint8_t> res = parseListNative<uint8_t>(list);
  c10::List<bool> boollist;
  for (auto x : res) {
    boollist.push_back(x);
  }
  return boollist;
}

IValue parseTuple(
    FlatbufferLoader& loader,
    const mobile::serialization::IValue& ivalue) {
  const auto& tuple = ivalue.val_as_Tuple();
  std::vector<IValue> res;
  for (int i : *tuple->items()) {
    res.emplace_back(loader.getIValue(i));
  }
  return c10::ivalue::Tuple::create(res);
}

IValue parseDict(
    FlatbufferLoader& loader,
    const mobile::serialization::IValue& ivalue) {
  const auto* dict = ivalue.val_as_Dict();
  auto result = c10::impl::GenericDict(AnyType::get(), AnyType::get());
  const auto* keys = dict->keys();
  const auto* values = dict->values();
  for (size_t i = 0; i < keys->size(); ++i) {
    uint32_t key = keys->Get(i);
    uint32_t val = values->Get(i);
    result.insert_or_assign(loader.getIValue(key), loader.getIValue(val));
  }
  auto type = loader.getOrCreateTypeAnnotations(dict->annotation_str());
  result.unsafeSetKeyType(type->containedType(0));
  result.unsafeSetValueType(type->containedType(1));
  return result;
}
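
// Class types are materialized lazily: the first time an object with a given
// type index is seen, its ClassType is either looked up (custom classes and
// already-registered classes) or created and registered on the loader's
// CompilationUnit, with attribute types inferred from the concrete object.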
ClassTypePtr FlatbufferLoader::getOrCreateClassTypeForObject(
    const mobile::serialization::Object* object) {
  auto cls = getType(object->type_index());
  const mobile::serialization::ObjectType* obj_type =
      module_->object_types()->Get(object->type_index());
  if (cls == nullptr) {
    c10::string_view qn_str(
        obj_type->type_name()->c_str(), obj_type->type_name()->size());
    if (qn_str.starts_with(kTorchPrefix) || qn_str.starts_with(kJitPrefix)) {
      c10::QualifiedName qn(obj_type->type_name()->str());
      cls = cu_->get_class(qn);
      if (cls == nullptr) {
        cls = ClassType::create(qn, cu_, true);
        cu_->register_type(cls);
      }
    } else {
      cls = c10::parseType(std::string(qn_str))->cast<ClassType>();
    }
    TORCH_CHECK(object->type_index() < all_ivalues_.size());
    all_types_[object->type_index()] = cls;

    if (obj_type->type() == mobile::serialization::TypeType::CLASS_WITH_FIELD) {
      for (uint32_t i = 0; i < object->attrs()->size(); i++) {
        IValue val = getIValue(object->attrs()->Get(i));
        // Need to use the concrete object's field's type to set the type of
        // the field.
        cls->addAttribute(obj_type->attr_names()->Get(i)->str(), val.type());
      }
    }
    initialized_types_.insert(object->type_index());
  }
  return cls;
}

IValue parseObject(
    FlatbufferLoader& loader,
    const mobile::serialization::IValue& ivalue) {
  const mobile::serialization::Object* object = ivalue.val_as_Object();
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(object != nullptr);
  const auto* cur_input = loader.getCurrentFlatbufferInput();
  const mobile::serialization::ObjectType* obj_type =
      cur_input->object_types()->Get(object->type_index());
  auto cls = loader.getOrCreateClassTypeForObject(object);
  Stack stack;
  switch (obj_type->type()) {
    case mobile::serialization::TypeType::CLASS_WITH_FIELD: {
      auto obj = c10::ivalue::Object::create(
          at::StrongTypePtr(loader.cu_, cls), object->attrs()->size());
      for (uint32_t i = 0; i < object->attrs()->size(); i++) {
        IValue val = loader.getIValue(object->attrs()->Get(i));
        obj->setSlot(i, std::move(val));
      }
      return obj;
    }
    case mobile::serialization::TypeType::CLASS_WITH_SETSTATE: {
      IValue input = loader.getIValue(object->state());
      mobile::Function* setstate = loader.getFunction(object->setstate_func());
      auto obj =
          c10::ivalue::Object::create(at::StrongTypePtr(loader.cu_, cls), 0);
      stack.push_back(obj);
      stack.emplace_back(std::move(input));
      setstate->run(stack);
      return obj;
    }
    case mobile::serialization::TypeType::CUSTOM_CLASS: {
      auto custom_class_type =
          torch::jit::getCustomClass(cls->name()->qualifiedName());
      IValue input = loader.getIValue(object->state());
      auto obj = c10::ivalue::Object::create(
          c10::StrongTypePtr(nullptr, custom_class_type), 1);
      stack.push_back(obj);
      stack.emplace_back(std::move(input));
      custom_class_type->getMethod("__setstate__").run(stack);
      return obj;
    }
    default:
      AT_ASSERT(false, "need to be object");
  }
}

IValue FlatbufferLoader::parseIValue(
    const mobile::serialization::IValue* ivalue) {
  return ivalue_parsers_[static_cast<uint32_t>(ivalue->val_type())](
      *this, *ivalue);
}

void deleteNothing2(void*);
void deleteNothing2(void*) {}

c10::Storage FlatbufferLoader::getStorage(uint32_t index) {
  TORCH_CHECK(index < storage_loaded_.size());
  TORCH_CHECK(index < storages_.size());
  if (!storage_loaded_[index]) {
    auto* storage = module_->storage_data()->GetMutableObject(index);
    size_t size = storage->data()->size();

    at::DataPtr data;
    if (should_copy_tensor_memory_) {
      auto* allocator = at::GetCPUAllocator();
      data = allocator->allocate(size);
      memcpy(data.get(), storage->data()->data(), size);
    } else {
      void* ptr = static_cast<void*>(storage->mutable_data()->data());
      data = at::DataPtr(ptr, ptr, deleteNothing2, DeviceType::CPU);
    }
    storages_[index] =
        c10::Storage(c10::Storage::use_byte_size_t(), size, std::move(data));
    storage_loaded_[index] = true;
  }
  return storages_[index];
}
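
// Note: in the non-copying path of getStorage() above, the returned Storage
// aliases the flatbuffer buffer itself (deleteNothing2 is a no-op deleter),
// so the buffer must outlive every tensor loaded from it. This is why
// parse_and_initialize_mobile_module() below transfers ownership of the
// buffer to the module via set_delete_memory().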
TypePtr FlatbufferLoader::getOrCreateTypeAnnotations(
    const flatbuffers::String* offset) {
  auto iter = type_annotations_.find(offset);
  if (iter != type_annotations_.end()) {
    return iter->second;
  }
  TypePtr type = type_resolver_(offset->str(), cu_);
  type_annotations_[offset] = type;
  return type;
}

void FlatbufferLoader::extractJitSourceAndConstants(
    ExtraFilesMap* jit_sources,
    std::vector<IValue>* constants) {
  AT_ASSERT(
      module_parsed_,
      "Need to first parse a flatbuffer file before extracting jit_sources");

  const auto* ivalues = module_->ivalues();
  for (uint32_t i = mobile_ivalue_size_; i < ivalues->size(); i++) {
    const auto* ival = ivalues->Get(i);
    parseAndPopulate(i, ival);
  }
  // register functions
  for (const auto& f : all_functions_) {
    if (f.first >= mobile_ivalue_size_) {
      uint32_t class_index =
          ivalues->Get(f.first)->val_as_Function()->class_type();
      ClassTypePtr class_type = all_types_[class_index];
      class_type->addMethod(f.second);
    }
  }
  const auto* jit_constants = module_->jit_constants();
  for (uint32_t i = 0; i < jit_constants->size(); ++i) {
    constants->emplace_back(getIValue(jit_constants->Get(i)));
  }
  parseExtraFilesFromVector(module_->jit_sources(), jit_sources);
}

mobile::Module parse_and_initialize_mobile_module(
    std::shared_ptr<char> data,
    size_t,
    c10::optional<at::Device>,
    ExtraFilesMap* extra_files) {
  TORCH_CHECK(
      mobile::serialization::ModuleBufferHasIdentifier(data.get()),
      "Format error");
  auto* flatbuffer_module = mobile::serialization::GetMutableModule(data.get());
  mobile::Module m = FlatbufferLoader().parseModule(flatbuffer_module);
  m.set_delete_memory(std::move(data));
  if (extra_files != nullptr) {
    parseExtraFiles(flatbuffer_module, *extra_files);
  }
  return m;
}

mobile::Module initialize_mobile_module(
    mobile::serialization::Module* flatbuffer_module,
    c10::optional<at::Device>,
    bool should_copy_tensor_memory) {
  auto flatbufferLoader = FlatbufferLoader();
  flatbufferLoader.setShouldCopyTensorMemory(should_copy_tensor_memory);
  mobile::Module m = flatbufferLoader.parseModule(flatbuffer_module);
  return m;
}

mobile::Module load_mobile_module_from_file(
    const std::string& filename,
    c10::optional<at::Device> device,
    ExtraFilesMap* extra_files) {
  std::shared_ptr<char> data;
  size_t size = 0;
  std::tie(data, size) = get_file_content(filename.c_str());
  return parse_and_initialize_mobile_module(
      std::move(data), size, device, extra_files);
}
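
// Example (sketch): loading a module from a flatbuffer file on disk and
// running it. The file name is hypothetical.
//
//   ExtraFilesMap extra_files;
//   mobile::Module m = torch::jit::load_mobile_module_from_file(
//       "model.ff", /*device=*/c10::nullopt, &extra_files);
//   auto out = m.forward({at::ones({2, 2})});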
uint64_t get_bytecode_version(std::istream& in) {
  std::shared_ptr<char> data;
  size_t size = 0;
  std::tie(data, size) = get_stream_content(in);
  return get_bytecode_version_from_bytes(data.get());
}

uint64_t get_bytecode_version(const std::string& filename) {
  std::shared_ptr<char> data;
  size_t size = 0;
  std::tie(data, size) = get_file_content(filename.c_str());
  return get_bytecode_version_from_bytes(data.get());
}

uint64_t get_bytecode_version_from_bytes(char* flatbuffer_content) {
  TORCH_CHECK(
      mobile::serialization::ModuleBufferHasIdentifier(flatbuffer_content),
      "Format error");
  auto* flatbuffer_module =
      mobile::serialization::GetMutableModule(flatbuffer_content);
  return flatbuffer_module->bytecode_version();
}

mobile::ModuleInfo get_module_info_from_flatbuffer(char* flatbuffer_content) {
  auto* ff_module = mobile::serialization::GetMutableModule(flatbuffer_content);
  FlatbufferLoader loader;
  loader.setShouldLoadOperators(false);
  mobile::Module m = loader.parseModule(ff_module);
  return mobile::get_module_info(m);
}

mobile::Module load_mobile_module_from_stream_with_copy(
    std::istream& in,
    c10::optional<at::Device> device,
    ExtraFilesMap* extra_files) {
  std::shared_ptr<char> data;
  size_t size = 0;
  std::tie(data, size) = get_stream_content(in);
  return parse_and_initialize_mobile_module(
      std::move(data), size, device, extra_files);
}

static mobile::Module parse_flatbuffer_no_object(
    std::shared_ptr<char> data,
    size_t size,
    c10::optional<at::Device> device) {
  (void)device;
  (void)size;
  auto* flatbuffer_module = mobile::serialization::GetMutableModule(data.get());
  FlatbufferLoader loader;
  // Replace the default Object parser with one that only handles the
  // CLASS_WITH_FIELD case.
  loader.registerIValueParser(
      mobile::serialization::IValueUnion::Object,
      +[](FlatbufferLoader& loader,
          const mobile::serialization::IValue& ivalue) {
        const mobile::serialization::Object* object = ivalue.val_as_Object();
        auto cls = loader.getOrCreateClassTypeForObject(object);
        auto obj = c10::ivalue::Object::create(
            at::StrongTypePtr(loader.cu_, cls), object->attrs()->size());
        for (uint32_t i = 0; i < object->attrs()->size(); i++) {
          IValue val = loader.getIValue(object->attrs()->Get(i));
          obj->setSlot(i, std::move(val));
        }
        return static_cast<c10::IValue>(obj);
      });

  mobile::Module m = loader.parseModule(flatbuffer_module);
  m.set_delete_memory(std::move(data));
  return m;
}

bool register_flatbuffer_loader() {
  load_flatbuffer_bytes = parse_and_initialize_mobile_module;
  load_flatbuffer_bytes_no_object = parse_flatbuffer_no_object;
  get_flatbuffer_bytecode_version = get_bytecode_version_from_bytes;
  return true;
}

const bool kRegisteredFlatbufferLoader = register_flatbuffer_loader();

} // namespace jit
} // namespace torch