mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Allow serialization of custom types inside Tensor
Summary: The use case is that sometimes we need a Tensor of custom type instead of POD or string. This diff allows one to delegate to BlobSerializerBase to further serialize the contents inside the Tensor. Design choices: (1) Each element is serialized as a BlobProto string, and stored in the repeated string field. (2) UNDEFINED is used as the enum value for the tensor data type, and the exact type string is stored in the additional field. (3) BlobSerializer is called on each item to obtain the serialized string. (4) This requires the custom type to have copy constructor - otherwise it will simply not be possible to copy over the deserialized content without explicit type. See blob_test.cc for an example. Reviewed By: sunnieshang Differential Revision: D6300196 fbshipit-source-id: 18bf94a22a07337e0fa83d3f1004b3651e38cf27
This commit is contained in:
parent
c04ec84e1a
commit
fc8532c89d
|
|
@ -416,10 +416,16 @@ void TensorSerializer<Context>::Serialize(
|
|||
proto.mutable_double_data(),
|
||||
&this->context_);
|
||||
break;
|
||||
case TensorProto_DataType_UNDEFINED:
|
||||
LOG(FATAL) << "TensorSerializer does not have a serialization "
|
||||
"implementation for " << input.meta().name();
|
||||
break;
|
||||
case TensorProto_DataType_UNDEFINED: {
|
||||
proto.mutable_string_data()->Reserve(chunkSize);
|
||||
Blob temp_blob;
|
||||
const char* raw_data = static_cast<const char*>(input.raw_data());
|
||||
for (int i = chunkBegin; i < chunkBegin + chunkSize; ++i) {
|
||||
temp_blob.ShareExternal(
|
||||
const_cast<char*>(raw_data + i * input.itemsize()), input.meta());
|
||||
proto.add_string_data(temp_blob.Serialize(""));
|
||||
}
|
||||
} break;
|
||||
// Note: we intentially do not provide "default:" so if any new data types
|
||||
// are added, the compiler should warn the user to add the case here.
|
||||
}
|
||||
|
|
@ -572,8 +578,21 @@ void TensorDeserializer<Context>::Deserialize(
|
|||
tensor->template mutable_data<double>() + chunkBegin,
|
||||
&context);
|
||||
break;
|
||||
case TensorProto_DataType_UNDEFINED:
|
||||
CAFFE_THROW("Cannot deserialize from a TensorProto UNDEFINED data type.");
|
||||
case TensorProto_DataType_UNDEFINED: {
|
||||
Blob temp_blob;
|
||||
void* raw_ptr = nullptr;
|
||||
for (int i = 0; i < chunkSize; ++i) {
|
||||
temp_blob.Deserialize(proto.string_data(i));
|
||||
if (i == 0) {
|
||||
raw_ptr = tensor->template raw_mutable_data(temp_blob.meta());
|
||||
}
|
||||
temp_blob.meta().copy()(
|
||||
temp_blob.GetRaw(),
|
||||
static_cast<char*>(raw_ptr) +
|
||||
(i + chunkBegin) * temp_blob.meta().itemsize(),
|
||||
1);
|
||||
}
|
||||
}
|
||||
}
|
||||
context.FinishDeviceComputation();
|
||||
}
|
||||
|
|
|
|||
|
|
@ -41,13 +41,52 @@ CAFFE2_DECLARE_bool(caffe2_serialize_fp16_as_bytes);
|
|||
namespace caffe2 {
|
||||
using namespace ::caffe2::db;
|
||||
namespace {
|
||||
class BlobTestFoo {};
|
||||
class BlobTestFoo {
|
||||
public:
|
||||
int32_t val;
|
||||
};
|
||||
class BlobTestBar {};
|
||||
}
|
||||
|
||||
CAFFE_KNOWN_TYPE(BlobTestFoo);
|
||||
CAFFE_KNOWN_TYPE(BlobTestBar);
|
||||
|
||||
class BlobTestFooSerializer : public BlobSerializerBase {
|
||||
public:
|
||||
BlobTestFooSerializer() {}
|
||||
~BlobTestFooSerializer() {}
|
||||
/**
|
||||
* Serializes a Blob. Note that this blob has to contain Tensor<Context>,
|
||||
* otherwise this function produces a fatal error.
|
||||
*/
|
||||
void Serialize(
|
||||
const Blob& blob,
|
||||
const string& name,
|
||||
SerializationAcceptor acceptor) override {
|
||||
CAFFE_ENFORCE(blob.IsType<BlobTestFoo>());
|
||||
|
||||
BlobProto blob_proto;
|
||||
blob_proto.set_name(name);
|
||||
blob_proto.set_type("BlobTestFoo");
|
||||
// For simplicity we will just serialize the 4-byte content as a string.
|
||||
blob_proto.set_content(std::string(
|
||||
reinterpret_cast<const char*>(&(blob.Get<BlobTestFoo>().val)),
|
||||
sizeof(int32_t)));
|
||||
acceptor(name, blob_proto.SerializeAsString());
|
||||
}
|
||||
};
|
||||
|
||||
class BlobTestFooDeserializer : public BlobDeserializerBase {
|
||||
public:
|
||||
void Deserialize(const BlobProto& proto, Blob* blob) override {
|
||||
blob->GetMutable<BlobTestFoo>()->val =
|
||||
reinterpret_cast<const int32_t*>(proto.content().c_str())[0];
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_BLOB_SERIALIZER((TypeMeta::Id<BlobTestFoo>()), BlobTestFooSerializer);
|
||||
REGISTER_BLOB_DESERIALIZER(BlobTestFoo, BlobTestFooDeserializer);
|
||||
|
||||
namespace {
|
||||
|
||||
TEST(BlobTest, Blob) {
|
||||
|
|
@ -590,6 +629,32 @@ TEST_SERIALIZATION_WITH_TYPE(uint8_t, int32_data)
|
|||
TEST_SERIALIZATION_WITH_TYPE(uint16_t, int32_data)
|
||||
TEST_SERIALIZATION_WITH_TYPE(int64_t, int64_data)
|
||||
|
||||
TEST(TensorTest, TensorSerialization_CustomType) {
|
||||
Blob blob;
|
||||
TensorCPU* tensor = blob.GetMutable<TensorCPU>();
|
||||
tensor->Resize(2, 3);
|
||||
for (int i = 0; i < 6; ++i) {
|
||||
tensor->mutable_data<BlobTestFoo>()[i].val = i;
|
||||
}
|
||||
string serialized = blob.Serialize("test");
|
||||
BlobProto proto;
|
||||
CHECK(proto.ParseFromString(serialized));
|
||||
EXPECT_EQ(proto.name(), "test");
|
||||
EXPECT_EQ(proto.type(), "Tensor");
|
||||
Blob new_blob;
|
||||
EXPECT_NO_THROW(new_blob.Deserialize(serialized));
|
||||
EXPECT_TRUE(new_blob.IsType<TensorCPU>());
|
||||
const TensorCPU& new_tensor = blob.Get<TensorCPU>();
|
||||
EXPECT_EQ(new_tensor.ndim(), 2);
|
||||
EXPECT_EQ(new_tensor.dim(0), 2);
|
||||
EXPECT_EQ(new_tensor.dim(1), 3);
|
||||
for (int i = 0; i < 6; ++i) {
|
||||
EXPECT_EQ(
|
||||
new_tensor.data<BlobTestFoo>()[i].val,
|
||||
tensor->data<BlobTestFoo>()[i].val);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(TensorTest, float16) {
|
||||
const TIndex kSize = 3000000;
|
||||
Blob blob;
|
||||
|
|
@ -811,6 +876,7 @@ TYPED_TEST(TypedTensorTest, BigTensorSerialization) {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct DummyType {
|
||||
/* This struct is used to test serialization and deserialization of huge
|
||||
* blobs, that are not tensors.
|
||||
|
|
|
|||
|
|
@ -60,6 +60,13 @@ message TensorProto {
|
|||
required int64 end = 2;
|
||||
}
|
||||
optional Segment segment = 11;
|
||||
// Optionally, a TensorProto can contain a custom type beyond the above
|
||||
// mentioned ones. In this case, each item of the tensor is going to be stored
|
||||
// as a serialized string in the repeated string field. During deserialization
|
||||
// BlobDeserializer is called with the type name specified here. Note that
|
||||
// this is not a performant path and one is strongly encouraged to use POD
|
||||
// types if possible.
|
||||
optional string custom_data_type = 12;
|
||||
}
|
||||
|
||||
message QTensorProto {
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user