Allow serialization of custom types inside Tensor

Summary:
The use case is that sometimes we need a Tensor of custom type instead of POD
or string. This diff allows one to delegate to BlobSerializerBase to further
serialize the contents inside the Tensor.

Design choices:
(1) Each element is serialized as a BlobProto string, and stored in the
repeated string field.
(2) UNDEFINED is used as the enum value for the tensor data type, and the exact
type string is stored in the additional field.
(3) BlobSerializer is called on each item to obtain the serialized string.
(4) This requires the custom type to have copy constructor - otherwise it
will simply not be possible to copy over the deserialized content without
explicit type.

See blob_test.cc for an example.

Reviewed By: sunnieshang

Differential Revision: D6300196

fbshipit-source-id: 18bf94a22a07337e0fa83d3f1004b3651e38cf27
This commit is contained in:
Yangqing Jia 2017-11-10 13:03:58 -08:00 committed by Facebook Github Bot
parent c04ec84e1a
commit fc8532c89d
3 changed files with 99 additions and 7 deletions

View File

@ -416,10 +416,16 @@ void TensorSerializer<Context>::Serialize(
proto.mutable_double_data(),
&this->context_);
break;
case TensorProto_DataType_UNDEFINED:
LOG(FATAL) << "TensorSerializer does not have a serialization "
"implementation for " << input.meta().name();
break;
case TensorProto_DataType_UNDEFINED: {
proto.mutable_string_data()->Reserve(chunkSize);
Blob temp_blob;
const char* raw_data = static_cast<const char*>(input.raw_data());
for (int i = chunkBegin; i < chunkBegin + chunkSize; ++i) {
temp_blob.ShareExternal(
const_cast<char*>(raw_data + i * input.itemsize()), input.meta());
proto.add_string_data(temp_blob.Serialize(""));
}
} break;
// Note: we intentially do not provide "default:" so if any new data types
// are added, the compiler should warn the user to add the case here.
}
@ -572,8 +578,21 @@ void TensorDeserializer<Context>::Deserialize(
tensor->template mutable_data<double>() + chunkBegin,
&context);
break;
case TensorProto_DataType_UNDEFINED:
CAFFE_THROW("Cannot deserialize from a TensorProto UNDEFINED data type.");
case TensorProto_DataType_UNDEFINED: {
Blob temp_blob;
void* raw_ptr = nullptr;
for (int i = 0; i < chunkSize; ++i) {
temp_blob.Deserialize(proto.string_data(i));
if (i == 0) {
raw_ptr = tensor->template raw_mutable_data(temp_blob.meta());
}
temp_blob.meta().copy()(
temp_blob.GetRaw(),
static_cast<char*>(raw_ptr) +
(i + chunkBegin) * temp_blob.meta().itemsize(),
1);
}
}
}
context.FinishDeviceComputation();
}

View File

@ -41,13 +41,52 @@ CAFFE2_DECLARE_bool(caffe2_serialize_fp16_as_bytes);
namespace caffe2 {
using namespace ::caffe2::db;
namespace {
class BlobTestFoo {};
class BlobTestFoo {
public:
int32_t val;
};
class BlobTestBar {};
}
CAFFE_KNOWN_TYPE(BlobTestFoo);
CAFFE_KNOWN_TYPE(BlobTestBar);
class BlobTestFooSerializer : public BlobSerializerBase {
public:
BlobTestFooSerializer() {}
~BlobTestFooSerializer() {}
/**
* Serializes a Blob. Note that this blob has to contain Tensor<Context>,
* otherwise this function produces a fatal error.
*/
void Serialize(
const Blob& blob,
const string& name,
SerializationAcceptor acceptor) override {
CAFFE_ENFORCE(blob.IsType<BlobTestFoo>());
BlobProto blob_proto;
blob_proto.set_name(name);
blob_proto.set_type("BlobTestFoo");
// For simplicity we will just serialize the 4-byte content as a string.
blob_proto.set_content(std::string(
reinterpret_cast<const char*>(&(blob.Get<BlobTestFoo>().val)),
sizeof(int32_t)));
acceptor(name, blob_proto.SerializeAsString());
}
};
class BlobTestFooDeserializer : public BlobDeserializerBase {
public:
void Deserialize(const BlobProto& proto, Blob* blob) override {
blob->GetMutable<BlobTestFoo>()->val =
reinterpret_cast<const int32_t*>(proto.content().c_str())[0];
}
};
REGISTER_BLOB_SERIALIZER((TypeMeta::Id<BlobTestFoo>()), BlobTestFooSerializer);
REGISTER_BLOB_DESERIALIZER(BlobTestFoo, BlobTestFooDeserializer);
namespace {
TEST(BlobTest, Blob) {
@ -590,6 +629,32 @@ TEST_SERIALIZATION_WITH_TYPE(uint8_t, int32_data)
TEST_SERIALIZATION_WITH_TYPE(uint16_t, int32_data)
TEST_SERIALIZATION_WITH_TYPE(int64_t, int64_data)
TEST(TensorTest, TensorSerialization_CustomType) {
Blob blob;
TensorCPU* tensor = blob.GetMutable<TensorCPU>();
tensor->Resize(2, 3);
for (int i = 0; i < 6; ++i) {
tensor->mutable_data<BlobTestFoo>()[i].val = i;
}
string serialized = blob.Serialize("test");
BlobProto proto;
CHECK(proto.ParseFromString(serialized));
EXPECT_EQ(proto.name(), "test");
EXPECT_EQ(proto.type(), "Tensor");
Blob new_blob;
EXPECT_NO_THROW(new_blob.Deserialize(serialized));
EXPECT_TRUE(new_blob.IsType<TensorCPU>());
const TensorCPU& new_tensor = blob.Get<TensorCPU>();
EXPECT_EQ(new_tensor.ndim(), 2);
EXPECT_EQ(new_tensor.dim(0), 2);
EXPECT_EQ(new_tensor.dim(1), 3);
for (int i = 0; i < 6; ++i) {
EXPECT_EQ(
new_tensor.data<BlobTestFoo>()[i].val,
tensor->data<BlobTestFoo>()[i].val);
}
}
TEST(TensorTest, float16) {
const TIndex kSize = 3000000;
Blob blob;
@ -811,6 +876,7 @@ TYPED_TEST(TypedTensorTest, BigTensorSerialization) {
}
}
}
struct DummyType {
/* This struct is used to test serialization and deserialization of huge
* blobs, that are not tensors.

View File

@ -60,6 +60,13 @@ message TensorProto {
required int64 end = 2;
}
optional Segment segment = 11;
// Optionally, a TensorProto can contain a custom type beyond the above
// mentioned ones. In this case, each item of the tensor is going to be stored
// as a serialized string in the repeated string field. During deserialization
// BlobDeserializer is called with the type name specified here. Note that
// this is not a performant path and one is strongly encouraged to use POD
// types if possible.
optional string custom_data_type = 12;
}
message QTensorProto {