#include "caffe2/core/blob_serialization.h"

#include <mutex>
#include <sstream>

#include "caffe2/core/blob.h"

CAFFE2_DEFINE_int(
    caffe2_tensor_chunk_size,
    1000000,
    "Chunk size to split tensor data into");

namespace caffe2 {
namespace {
/**
 * @brief StringSerializer is the serializer for String.
 *
 * StringSerializer takes in a blob that contains a std::string, and
 * serializes it into a BlobProto protocol buffer.
 */
class StringSerializer : public BlobSerializerBase {
 public:
  StringSerializer() {}
  ~StringSerializer() {}

  /**
   * Serializes a Blob. Note that this blob has to contain a std::string,
   * otherwise the CHECK below fails.
   *
   * @param blob     The blob to serialize.
   * @param name     Name stored in the BlobProto and passed to the acceptor.
   * @param acceptor Callback that receives (name, serialized BlobProto); it is
   *                 invoked exactly once.
   */
  void Serialize(
      const Blob& blob,
      const string& name,
      SerializationAcceptor acceptor) override {
    CHECK(blob.IsType<std::string>());

    BlobProto blob_proto;
    blob_proto.set_name(name);
    // The type tag matches the key StringDeserializer is registered under.
    blob_proto.set_type("std::string");
    blob_proto.set_content(blob.Get<std::string>());
    acceptor(name, blob_proto.SerializeAsString());
  }
};

/**
 * @brief StringDeserializer is the deserializer for Strings.
 *
 * Copies the `content` field of the proto into the blob as a std::string and
 * always reports success.
 */
class StringDeserializer : public BlobDeserializerBase {
 public:
  bool Deserialize(const BlobProto& proto, Blob* blob) override {
    *blob->GetMutable<std::string>() = proto.content();
    return true;
  }
};
} // namespace

namespace {
// Maps a device type to the key under which the matching tensor deserializer
// is looked up. Raises via CAFFE_THROW for any device other than CPU or CUDA.
// We can't use DeviceType_Name because of a protobuf-lite constraint.
std::string tensorDeviceTypeName(const DeviceType& d) {
  switch (d) {
    case CPU:
      return "TensorCPU";
    case CUDA:
      return "TensorCUDA";
    default:
      CAFFE_THROW("Unknown device: ", d);
      // Not reached when CAFFE_THROW raises; kept so every path returns.
      return "";
  }
}
} // namespace

// The blob serialization member function implementation.
// Looks up the serializer registered for the stored type id, enforces that one
// exists, and delegates to it; serialized output is delivered via `acceptor`.
void Blob::Serialize(
    const string& name,
    BlobSerializerBase::SerializationAcceptor acceptor) const {
  std::unique_ptr<BlobSerializerBase> serializer(CreateSerializer(meta_.id()));
  CAFFE_ENFORCE(serializer, "No known serializer for ", meta_.name());
  serializer->Serialize(*this, name, acceptor);
}

// The blob serialization member function implementation.
std::string Blob::Serialize(const string& name) const { std::stringstream data; std::mutex mutex; BlobSerializerBase::SerializationAcceptor acceptor = [&data, &mutex](const std::string& name, const std::string& blob) { std::lock_guard guard(mutex); data << blob; }; this->Serialize(name, acceptor); return data.str(); } // Specialization for StoreDeviceDetail for CPU - nothing needs to be done. template <> void TensorSerializer::StoreDeviceDetail( const Tensor& input, TensorProto* proto) {} // The actual serialization registry objects. CAFFE_DEFINE_TYPED_REGISTRY( BlobSerializerRegistry, CaffeTypeId, BlobSerializerBase); CAFFE_DEFINE_REGISTRY(BlobDeserializerRegistry, BlobDeserializerBase); bool Blob::Deserialize(const string& content) { BlobProto blob_proto; if (!blob_proto.ParseFromString(content)) { LOG(ERROR) << "Cannot parse content into a BlobProto."; return false; } return Deserialize(blob_proto); } bool Blob::Deserialize(const BlobProto& blob_proto) { if (blob_proto.has_tensor()) { // This is a tensor object. Depending on the device type, we will // use the corresponding TensorDeserializer. auto deserializer = CreateDeserializer(tensorDeviceTypeName( blob_proto.tensor().device_detail().device_type())); // Tensor's deserializer should always be registered, but we will double // check if it is not null anyway. return CHECK_NOTNULL(deserializer.get())->Deserialize(blob_proto, this); } else { auto deserializer = CreateDeserializer(blob_proto.type()); if (!deserializer.get()) { LOG(ERROR) << "No registered deserializer for type " << blob_proto.type(); return false; } return deserializer->Deserialize(blob_proto, this); } } namespace { // Serialize TensorCPU. REGISTER_BLOB_SERIALIZER( (TypeMeta::Id()), TensorSerializer); REGISTER_BLOB_DESERIALIZER(TensorCPU, TensorDeserializer); // Serialize std::string REGISTER_BLOB_SERIALIZER((TypeMeta::Id()), StringSerializer); REGISTER_BLOB_DESERIALIZER(std::string, StringDeserializer); } // namespace } // namespace caffe2