#include #include #include #include #include namespace torch { namespace jit { using ::c10::IValue; // Protocol 2 is the highest that can be decoded by Python 2 // See https://docs.python.org/3/library/pickle.html#data-stream-format constexpr static uint8_t PROTOCOL_VERSION = 2; PicklerClass getClass(const std::string& str) { if (str == "build_tensor_from_id") { return PicklerClass::TENSOR; } else if (str == "build_intlist") { return PicklerClass::INTLIST; } else if (str == "build_tensorlist") { return PicklerClass::TENSORLIST; } else if (str == "build_doublelist") { return PicklerClass::DOUBLELIST; } else if (str == "build_boollist") { return PicklerClass::BOOLLIST; } // TODO [unpickler refactor] if (str == "TensorID") { return PicklerClass::TENSOR; } else if (str == "IntList") { return PicklerClass::INTLIST; } AT_ERROR("Unknown class name for unpickler: ", str); } const char* getClassName(PicklerClass cls) { switch (cls) { case PicklerClass::TENSOR: return "build_tensor_from_id"; case PicklerClass::INTLIST: return "build_intlist"; case PicklerClass::TENSORLIST: return "build_tensorlist"; case PicklerClass::DOUBLELIST: return "build_doublelist"; case PicklerClass::BOOLLIST: return "build_boollist"; default: AT_ERROR("Unknown class for pickler"); } } const std::vector& Pickler::stack() { return stack_; } void Pickler::protocol() { push(OpCode::PROTO); push(PROTOCOL_VERSION); } void Pickler::startTuple() { // All attributes get pushed into a tuple and their indices saved in the // module def push(OpCode::MARK); } void Pickler::endTuple() { push(OpCode::TUPLE); } void Pickler::stop() { push(OpCode::STOP); } void Pickler::torchSaveStop() { // Add the binary data for all the tensors to be included in the same binary // TODO: The pickler should be refactored to stream out to a stream directly // instead of staging in the stack_ array // As another pickle program in the same binary archive, add a list of // keys for each tensor (see torch/serialization.py) protocol(); push(OpCode::MARK); for (size_t i = 0; i < tensor_data_.size(); ++i) { std::string key = std::to_string(i); push(OpCode::BINUNICODE); push(key.size()); pushBytes(key); } push(OpCode::TUPLE); stop(); // Now dump the tensor binary data for (const auto& data : tensor_data_) { // first dump size push(data.numel()); stack_.insert(stack_.end(), data.data(), data.data() + data.sizeInBytes()); } } void Pickler::torchSaveStart() { // Output data to match torch.save, see torch/serialization.py for details // Magic number (0x1950a86a20f9469cfc6c) protocol(); push(OpCode::LONG1); // LONG1 size pushBytes("\x0a"); // LONG1 data pushBytes("\x6c\xfc\x9c\x46\xf9\x20\x6a\xa8\x50\x19"); stop(); // Protocol Version (1001) protocol(); push(OpCode::BININT2); pushBytes("\xe9\x03"); stop(); // sys_info, this isn't actually used in de-serialization so we can leave this // one empty protocol(); push(OpCode::EMPTY_DICT); stop(); } // unmemoized version called by pushIValue void Pickler::pushIValueImpl(const IValue& ivalue) { if (ivalue.isTensor()) { pushTensor(ivalue); } else if (ivalue.isTuple()) { pushTuple(ivalue); } else if (ivalue.isDouble()) { pushDouble(ivalue); } else if (ivalue.isInt()) { pushInt(ivalue); } else if (ivalue.isBool()) { if (ivalue.toBool()) { push(OpCode::NEWTRUE); } else { push(OpCode::NEWFALSE); } } else if (ivalue.isString()) { pushStringImpl(ivalue.toStringRef()); } else if (ivalue.isGenericList()) { pushGenericList(ivalue); } else if (ivalue.isGenericDict()) { pushDict(ivalue); } else if (ivalue.isNone()) { push(OpCode::NONE); } else if (ivalue.isIntList()) { pushSpecializedList( ivalue, PicklerClass::INTLIST, [=](const IValue& ivalue) { for (const int64_t item : ivalue.toIntListRef()) { pushIValue(item); } }); } else if (ivalue.isTensorList()) { pushSpecializedList( ivalue, PicklerClass::TENSORLIST, [=](const IValue& ivalue) { for (const at::Tensor& item : ivalue.toTensorListRef()) { pushIValue(item); } }); } else if (ivalue.isDoubleList()) { pushSpecializedList( ivalue, PicklerClass::DOUBLELIST, [=](const IValue& ivalue) { for (double item : ivalue.toDoubleListRef()) { pushIValue(item); } }); } else if (ivalue.isBoolList()) { pushSpecializedList( ivalue, PicklerClass::BOOLLIST, [=](const IValue& ivalue) { for (bool item : ivalue.toBoolList()) { pushIValue(item); } }); } else if (ivalue.isObject()) { auto obj = ivalue.toObject(); auto type = obj->type(); pushGlobal(type->qualifier(), type->basename()); push(OpCode::EMPTY_TUPLE); push(OpCode::NEWOBJ); if (checkHasValidSetGetState(type)) { Function* getstate = type->getMethod("__getstate__"); pushIValue((*getstate)({obj})); } else { push(OpCode::EMPTY_DICT); push(OpCode::MARK); for (size_t i = 0, n = type->numAttributes(); i < n; ++i) { pushString(type->getAttributeName(i)); pushIValue(obj->getSlot(i)); } push(OpCode::SETITEMS); } push(OpCode::BUILD); } else { AT_ERROR("Unknown IValue type for pickling: ", ivalue.tagKind()); } } void Pickler::pushIValue(const IValue& ivalue) { // Check if reference ivalue has been saved before if (ivalue.isPtrType()) { const void* ptr = ivalue.internalToPointer(); TORCH_CHECK( ptr != nullptr, "Pickler cannot memoize ", ivalue.tagKind(), " IValue ", ivalue); auto memo_entry = memoized_ivalue_map_.find(ptr); if (memo_entry != memoized_ivalue_map_.end()) { // This value has already been pushed, just do a BINGET pushBinGet(memo_entry->second); return; } } pushIValueImpl(ivalue); if (ivalue.isPtrType()) { memoized_ivalues_.push_back(ivalue); memoized_ivalue_map_[ivalue.internalToPointer()] = pushNextBinPut(); } } void Pickler::pushInt(const IValue& ivalue) { auto n = ivalue.toInt(); if (n >= std::numeric_limits::min() && n <= std::numeric_limits::max()) { push(OpCode::BININT1); push(n); } else if ( n >= std::numeric_limits::min() && n <= std::numeric_limits::max()) { push(OpCode::BININT); push(n); } else { // Push 8 byte integer push(OpCode::LONG1); push(8); push(n); } } void Pickler::pushBinGet(uint32_t memo_id) { if (memo_id <= std::numeric_limits::max()) { push(OpCode::BINGET); push(memo_id); } else { // Memoized too many items, issue a LONG_BINGET instead push(OpCode::LONG_BINGET); push(memo_id); } } // unmemoized encoding of a string void Pickler::pushStringImpl(const std::string& string) { push(OpCode::BINUNICODE); push(string.size()); pushBytes(string); } void Pickler::pushString(const std::string& string) { auto it = memoized_strings_map_.find(string); if (it == memoized_strings_map_.end()) { pushStringImpl(string); memoized_strings_map_[string] = pushNextBinPut(); } else { pushBinGet(it->second); } } void Pickler::pushStorageOfTensor(const at::Tensor& tensor) { const at::Storage& storage = tensor.storage(); void* addr = storage.unsafeGetStorageImpl(); auto it = memoized_storage_map_.find(addr); if (it != memoized_storage_map_.end()) { pushBinGet(it->second); return; } // Tuple for persistent_load push(OpCode::MARK); // typename pushString("storage"); // data_type std::stringstream data_type; data_type << toString(tensor.scalar_type()) << "Storage"; pushGlobal("torch", data_type.str()); // root_key pushString(std::to_string(tensor_data_.size())); // location pushString("cpu"); // size pushInt(tensor.numel()); // view_metadata push(OpCode::NONE); push(OpCode::TUPLE); push(OpCode::BINPERSID); memoized_storage_map_[addr] = pushNextBinPut(); tensor_data_.push_back(getWriteableTensorData(tensor)); } void Pickler::pushBytes(const std::string& string) { stack_.insert(stack_.end(), string.begin(), string.end()); } void Pickler::pushGlobal( const std::string& module_name, const std::string& class_name) { std::stringstream ss; ss << module_name << "\n" << class_name << "\n"; std::string key = ss.str(); auto memo_entry = memoized_globals_map_.find(key); if (memo_entry == memoized_globals_map_.end()) { push(OpCode::GLOBAL); pushBytes(key); // Push BINPUT without adding anything to the memoized_ivalues_ size_t memo_id = pushNextBinPut(); memoized_globals_map_.insert({key, memo_id}); } else { pushBinGet(memo_entry->second); } } void Pickler::pushTensor(const IValue& ivalue) { if (tensor_table_ == nullptr) { pushLiteralTensor(ivalue); } else { pushTensorReference(ivalue); } } void Pickler::pushLiteralTensor(const IValue& ivalue) { // In contrast to tensor references, literal tensors are included in the // pickle program binary blob. They are written to the file after the STOP // opcode. They can't be included in the pickle program itself without a bunch // of extra machinery since byte strings are limited to 4 GB. // // The format here is the same one used by `torch.save()`. The code for the // format can be found in `torch/serialization.py`. auto tensor = ivalue.toTensor(); // The arguments to this function are: // storage, storage_offset, size, stride, requires_grad, backward_hooks pushGlobal("torch._utils", "_rebuild_tensor_v2"); push(OpCode::MARK); pushStorageOfTensor(tensor); // storage offset int64_t storage_offset = 0; pushInt(storage_offset); // size push(OpCode::MARK); for (auto size : tensor.sizes()) { pushInt(size); } push(OpCode::TUPLE); // stride push(OpCode::MARK); for (auto stride : tensor.strides()) { pushInt(stride); } push(OpCode::TUPLE); // requires_grad pushIValue(tensor.requires_grad()); // backward_hooks pushGlobal("collections", "OrderedDict"); push(OpCode::EMPTY_TUPLE); // Construct the collections.OrderedDict for the backward_hooks push(OpCode::REDUCE); push(OpCode::TUPLE); // Call torch._utils._rebuild_tensor_v2 push(OpCode::REDUCE); } void Pickler::pushClass(PicklerClass cls) { pushGlobal("torch.jit._pickle", getClassName(cls)); } void Pickler::pushTensorReference(const IValue& ivalue) { pushClass(PicklerClass::TENSOR); tensor_table_->push_back(ivalue.toTensor()); int64_t tensor_id = tensor_table_->size() - 1; // Reduce arguments are spread (e.g. `*args`) before calling the global, // so wrap in a tuple push(OpCode::MARK); pushIValue(tensor_id); push(OpCode::TUPLE); push(OpCode::REDUCE); } void Pickler::pushSpecializedList( const IValue& ivalue, PicklerClass cls, const std::function& item_pusher) { pushClass(cls); // Reduce arguments are spread (e.g. `*args`) before calling the global, // so wrap in a tuple push(OpCode::MARK); push(OpCode::EMPTY_LIST); // Mark list push(OpCode::MARK); // Add all items item_pusher(ivalue); // Finish list push(OpCode::APPENDS); // Finish tuple push(OpCode::TUPLE); // Call reduce push(OpCode::REDUCE); } void Pickler::pushDouble(const IValue& ivalue) { double value = ivalue.toDouble(); AT_ASSERT(sizeof(double) == 8); char* bytes = reinterpret_cast(&value); push(OpCode::BINFLOAT); for (size_t i = 0; i < 8; ++i) { push(bytes[8 - i - 1]); } } void Pickler::pushDict(const IValue& ivalue) { push(OpCode::EMPTY_DICT); push(OpCode::MARK); // Sort the dict for deterministic keys auto dict_items = iterationOrder(ivalue.toGenericDict()); for (const auto& pair : dict_items) { pushIValue(pair.first); pushIValue(pair.second); } push(OpCode::SETITEMS); } size_t Pickler::pushNextBinPut() { if (memo_id_ <= std::numeric_limits::max()) { push(OpCode::BINPUT); push(memo_id_); } else { // Memoized too many items, issue a LONG_BINPUT instead push(OpCode::LONG_BINPUT); push(memo_id_); } AT_ASSERT(memo_id_ <= std::numeric_limits::max()); ++memo_id_; return memo_id_ - 1; } void Pickler::pushGenericList(const IValue& ivalue) { auto list = ivalue.toGenericListRef(); push(OpCode::EMPTY_LIST); push(OpCode::MARK); for (const IValue& item : list) { pushIValue(item); } push(OpCode::APPENDS); } void Pickler::pushTuple(const IValue& ivalue) { // TODO: Small tuple unrolling (e.g. TUPLE3) push(OpCode::MARK); auto tuple = ivalue.toTuple(); for (const IValue& item : tuple->elements()) { pushIValue(item); } push(OpCode::TUPLE); } std::vector Unpickler::parse_ivalue_list() { run(); TORCH_CHECK( stack_.size() == 1, "Unpickler expected 1 element on the stack, but found ", stack_.size()); auto value = stack_[0]; if (value.isGenericList()) { // TODO [unpickler refactor] return value.toGenericListRef().vec(); } return value.toTuple()->elements(); } double Unpickler::readFloat() { AT_ASSERT(sizeof(double) == 8); AT_ASSERT(bytes_ + 8 < end_ptr_); double result; // Pickle floats are big endian, so reverse the bytes std::reverse_copy( reinterpret_cast(bytes_), reinterpret_cast(bytes_ + 8), reinterpret_cast(&result)); bytes_ += 8; return result; } void Unpickler::run() { // Expect a PROTO opcode and protocol number at the start of blob TORCH_CHECK( readOpCode() == OpCode::PROTO, "Expected PROTO opcode at the start" " of pickle archive"); uint8_t protocol = read(); TORCH_CHECK( protocol == 2, "Only Pickle protocol 2 is supported, found protocol = ", protocol); while (bytes_ < end_ptr_) { OpCode opcode = readInstruction(); if (opcode == OpCode::STOP) { return; } } AT_ERROR("Overran buffer while unpickling data, didn't find STOP opcode"); } void Unpickler::setInput(size_t memo_id) { AT_ASSERT(!stack_.empty()); if (memo_id >= memo_table_.size()) { memo_table_.insert( memo_table_.end(), memo_id - memo_table_.size(), IValue()); memo_table_.push_back(stack_.back()); } else { memo_table_[memo_id] = stack_.back(); } } // emplace_back on bool vectors does not exist on some systems // avoid it by calling push_back for bool template inline void append(std::vector& a, T&& e) { a.emplace_back(std::move(e)); } template <> inline void append(std::vector& a, bool&& e) { a.push_back(e); } template static IValue toSpecializedList(const IValue& generic) { auto ivalues = generic.toGenericListRef(); std::vector specialized; specialized.reserve(ivalues.size()); for (const IValue& iv : ivalues) { append(specialized, iv.to()); } return IValue(std::move(specialized)); } OpCode Unpickler::readInstruction() { auto opcode = readOpCode(); switch (opcode) { case OpCode::EMPTY_LIST: { stack_.emplace_back( c10::impl::GenericList(c10::impl::deprecatedUntypedList())); } break; case OpCode::EMPTY_TUPLE: { if (empty_tuple_.isNone()) { // we only need one object, since tuples are not mutable. empty_tuple_ = c10::ivalue::Tuple::create({}); } stack_.emplace_back(empty_tuple_); } break; case OpCode::BINPUT: { size_t memo_id = read(); setInput(memo_id); } break; case OpCode::LONG_BINPUT: { TORCH_CHECK( std::numeric_limits::max() >= std::numeric_limits::max(), "Found a LONG_BINPUT opcode, but size_t on this system is " "not big enough to decode it"); size_t memo_id = read(); setInput(memo_id); } break; case OpCode::MARK: { // Mark location of the container ivalue in the stack marks_.push_back(stack_.size()); } break; case OpCode::NEWTRUE: { stack_.emplace_back(true); } break; case OpCode::NEWFALSE: { stack_.emplace_back(false); } break; case OpCode::NONE: { stack_.emplace_back(IValue()); } break; case OpCode::BININT1: { int8_t value = read(); stack_.emplace_back(int64_t(value)); } break; case OpCode::BININT: { int32_t value = read(); stack_.emplace_back(int64_t(value)); } break; case OpCode::LONG1: { // Only read LONG1s with 8 as the length uint8_t length = read(); AT_ASSERT(length == 8); stack_.emplace_back(int64_t(read())); } break; case OpCode::BINUNICODE: { uint32_t length = read(); const char* characters = reinterpret_cast(bytes_); AT_ASSERT(bytes_ + length < end_ptr_); bytes_ += length; stack_.emplace_back(std::string(characters, /*n=*/length)); } break; case OpCode::BINFLOAT: stack_.emplace_back(readFloat()); break; case OpCode::TUPLE: { size_t start = marks_.back(); marks_.pop_back(); auto tuple = c10::ivalue::Tuple::create({}); tuple->elements().reserve(stack_.size() - start); auto start_it = stack_.begin() + start; for (auto it = start_it; it != stack_.end(); ++it) { tuple->elements().emplace_back(*it); } stack_.erase(start_it, stack_.end()); stack_.emplace_back(tuple); } break; case OpCode::EMPTY_DICT: stack_.emplace_back(c10::impl::GenericDict(c10::impl::deprecatedUntypedDict())); break; case OpCode::APPENDS: { readList(); } break; case OpCode::SETITEMS: { size_t start = marks_.back(); marks_.pop_back(); auto dict = stack_.at(start - 1).toGenericDict(); for (size_t i = start; i < stack_.size(); i += 2) { dict.insert_or_assign(stack_[i], stack_[i + 1]); } stack_.erase(stack_.begin() + start, stack_.end()); } break; case OpCode::BINGET: { stack_.push_back(memo_table_.at(read())); } break; case OpCode::LONG_BINGET: { stack_.push_back(memo_table_.at(read())); } break; case OpCode::STOP: break; case OpCode::GLOBAL: { // Module name, it's not needed for anything auto module_name = readString(); auto class_name = readString(); // TODO [unpickler refactor] __main__ isn't used by the pickler anymore if (module_name == "__main__") { auto pickler_class = getClass(class_name); globals_.emplace_back([this, pickler_class] { // TODO: [unpickler refactor] auto setitem_data = stack_.back(); stack_.pop_back(); switch (pickler_class) { case PicklerClass::TENSOR: stack_.emplace_back(tensor_table_->at(setitem_data.toInt())); break; case PicklerClass::INTLIST: stack_.emplace_back(toSpecializedList(setitem_data)); break; default: AT_ERROR("Unknown pickler class id", pickler_class); } }); } else if (module_name == "torch.jit._pickle") { auto pickler_class = getClass(class_name); globals_.emplace_back([this, pickler_class] { // Pop reduce arg off the stack auto data = stack_.back().toTuple()->elements().at(0); stack_.pop_back(); switch (pickler_class) { case PicklerClass::TENSOR: stack_.emplace_back(tensor_table_->at(data.toInt())); break; case PicklerClass::INTLIST: stack_.emplace_back(toSpecializedList(data)); break; case PicklerClass::TENSORLIST: stack_.emplace_back(toSpecializedList(data)); break; case PicklerClass::DOUBLELIST: stack_.emplace_back(toSpecializedList(data)); break; case PicklerClass::BOOLLIST: stack_.emplace_back(toSpecializedList(data)); break; default: AT_ERROR("Unknown pickler class id"); } }); } else { AT_ASSERT(class_resolver_); at::StrongTypePtr type = class_resolver_(c10::QualifiedName(module_name, class_name)); auto cls = type.type_->expect(); size_t n = cls->numAttributes(); if (checkHasValidSetGetState(type.type_)) { globals_.emplace_back([this, type, n] { auto arg = std::move(stack_.back()); stack_.pop_back(); auto obj = c10::ivalue::Object::create(type, n); (*type.type_->getMethod("__setstate__"))({obj, arg}); stack_.emplace_back(std::move(obj)); }); } else { globals_.emplace_back([this, type, cls, n] { auto dict = std::move(stack_.back()).toGenericDict(); stack_.pop_back(); auto obj = c10::ivalue::Object::create(type, n); for (size_t i = 0; i < n; ++i) { obj->setSlot(i, dict.at(cls->getAttributeName(i))); } stack_.emplace_back(std::move(obj)); }); } } stack_.emplace_back(int64_t(globals_.size() - 1)); } break; case OpCode::NEWOBJ: { // pop empty tuple, the actual action is stored in the globals_stack_ stack_.pop_back(); } break; // because we have NEWOBJ do nothing, BUILD and REDUCE end up doing // the same thing case OpCode::BUILD: case OpCode::REDUCE: { // stack is: // extract and remove from the stack: std::swap(*(stack_.end() - 2), *(stack_.end() - 1)); size_t idx = stack_.back().toInt(); stack_.pop_back(); // stack is: globals_.at(idx)(); } break; default: AT_ERROR( "Unknown opcode for unpickling at ", reinterpret_cast(opcode), ": ", static_cast(opcode)); } return opcode; } // Pop all the list items off of the stack and append them to the list at the // corresponding MARK void Unpickler::readList() { size_t start = marks_.back(); marks_.pop_back(); auto list_ivalue = stack_.at(start - 1); auto num_elements = stack_.size() - start; auto elements = at::ArrayRef(stack_).slice(start); if (list_ivalue.isIntList()) { auto list = std::move(list_ivalue).toIntList(); list.reserve(num_elements); for (const auto& elem : elements) { list.emplace_back(elem.toInt()); } } else if (list_ivalue.isTensorList()) { auto list = std::move(list_ivalue).toTensorList(); list.reserve(num_elements); for (const auto& elem : elements) { list.emplace_back(elem.toTensor()); } } else if (list_ivalue.isDoubleList()) { auto list = std::move(list_ivalue).toDoubleList(); list.reserve(num_elements); for (const auto& elem : elements) { list.emplace_back(elem.toDouble()); } } else if (list_ivalue.isBoolList()) { auto list = std::move(list_ivalue).toBoolList(); list.reserve(num_elements); for (const auto& elem : elements) { list.push_back(elem.toBool()); } } else if (list_ivalue.isGenericList()) { auto list = std::move(list_ivalue).toGenericList(); list.reserve(num_elements); for (const auto& elem : elements) { list.emplace_back(elem); } } else { AT_ERROR("Unknown IValue list kind: ", list_ivalue.tagKind()); } stack_.erase(stack_.begin() + start, stack_.end()); } inline bool is_valid_python_id_char(char c) { return c == '_' || c == '.' || (c >= '0' && c <= '9') || (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z'); } // Read a newline terminated string std::string Unpickler::readString() { const char* chars = reinterpret_cast(bytes_); const char* char_end_ptr = reinterpret_cast(end_ptr_); size_t n = 0; while (true) { char c = chars[n]; if (c == '\n') { break; } // Simple check just in case there is no terminating '\n' TORCH_CHECK( is_valid_python_id_char(c), "Found character '", uint8_t(c), "' in string, " "strings must be qualified Python identifiers"); // Increment after to exclude newline from string ++n; TORCH_CHECK( chars + n < char_end_ptr, "Unpickler overran buffer while reading a string (expected a newline)"); } // Increment by string length + newline char bytes_ += n + 1; return std::string(chars, n); } OpCode Unpickler::readOpCode() { return static_cast(read()); } WriteableTensorData getWriteableTensorData(const at::Tensor& tensor) { WriteableTensorData result; result.tensor_ = tensor; result.size_ = tensor.element_size() * tensor.storage().size(); // TODO HIP support if (tensor.storage().device_type() == at::DeviceType::CUDA) { // NB: This new tensor is created to support cuda tensors. // Storages can be mutated when converting tensors from cuda to cpu, // and we need a cpu tensor to copy data from. result.tensor_ = at::empty({0}, tensor.options()) .set_( tensor.storage(), /* storage_offset = */ 0, /* size = */ {static_cast(tensor.storage().size())}, /* stride = */ {1}) .cpu(); TORCH_CHECK( result.tensor_.element_size() * result.tensor_.storage().size() == result.size_, "Storage tensor size did not match record size"); } return result; } bool checkHasValidSetGetState(const std::shared_ptr& cls) { // Check that the schemas for __getstate__ and __setstate__ are correct auto getstate = cls->getMethod("__getstate__"); if (getstate == nullptr) { return false; } auto get_schema = getstate->getSchema(); // Check __getstate__ // __getstate__ is expected to be (self) -> T TORCH_CHECK( get_schema.arguments().size() == 1, "'__getstate__' must have 'self' as its only argument, but found ", get_schema.arguments().size(), " arguments"); TORCH_CHECK( get_schema.returns().size() == 1, "'__getstate__' must return 1 value, but found ", get_schema.returns().size()); // Check __setstate__ if the method exists // __setstate__ is expected to be (self, T) -> None auto setstate = cls->getMethod("__setstate__"); if (!setstate) { return false; } auto set_schema = setstate->getSchema(); TORCH_CHECK( set_schema.arguments().size() == 2, "'__setstate__' must have 'self' and the state as its " "only arguments, but found ", set_schema.arguments().size(), " arguments"); TORCH_CHECK( set_schema.returns().size() == 1, "'__setstate__' must return None, but found ", set_schema.returns().size(), " return values"); TORCH_CHECK( set_schema.returns().at(0).type()->isSubtypeOf(NoneType::get()), "'__setstate__' must return None, but found value of type", set_schema.returns().at(0).type()->python_str()); // Check that the return type of __getstate__ matches the input to // __setstate__ auto get_type = get_schema.returns().at(0).type(); auto set_type = set_schema.arguments().at(1).type(); TORCH_CHECK( set_type->isSubtypeOf(get_type), "'__getstate__'s return type (", get_type->python_str(), " does not match '__setstate__'s argument type (", set_type->python_str(), "))"); return true; } } // namespace jit } // namespace torch