pytorch/caffe2/serialize/inline_container_test.cc
Lu Fang af6eea9391 Add the support of feature store example in pytorch model in fblearner (#20040)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/20040

Add the support of feature store example in fblearner pytorch predictor, end to end

Reviewed By: dzhulgakov

Differential Revision: D15177897

fbshipit-source-id: 0f6df8b064eb9844fc9ddae61e978d6574c22916
2019-05-20 12:58:27 -07:00

66 lines
1.8 KiB
C++

#include <array>
#include <cstddef>
#include <cstdio>
#include <cstring>
#include <sstream>
#include <string>
#include <tuple>

#include <gtest/gtest.h>

#include "caffe2/serialize/inline_container.h"
namespace caffe2 {
namespace serialize {
namespace {

// Fills `buf` with the descending pattern N, N-1, ..., 1 (converted to char)
// so that each record carries distinct, easily recognizable bytes.
template <size_t N>
void fillDescending(std::array<char, N>& buf) {
  for (size_t i = 0; i < N; ++i) {
    buf[i] = static_cast<char>(N - i);
  }
}

// Round-trips two records through PyTorchStreamWriter / PyTorchStreamReader
// entirely in memory and checks that:
//   * both written keys are reported present and an unknown key is not,
//   * getRecord() returns exactly the size and bytes that were written,
//   * getRecordOffset() points at the raw payload inside the serialized
//     bytes, and that offset is a multiple of kFieldAlignment.
// Nothing is written to disk, so the test leaves no artifacts behind.
TEST(PyTorchStreamWriterAndReader, SaveAndLoad) {
  // Alignment (in bytes) this test expects for the start of every payload.
  constexpr size_t kFieldAlignment = 64;

  // Write records through the writer into an in-memory stream.
  std::ostringstream oss;
  PyTorchStreamWriter writer(&oss);
  // 127 is not a multiple of kFieldAlignment; presumably chosen so the
  // following record exercises the writer's alignment padding.
  std::array<char, 127> data1;
  fillDescending(data1);
  writer.writeRecord("key1", data1.data(), data1.size());
  std::array<char, 64> data2;
  fillDescending(data2);
  writer.writeRecord("key2", data2.data(), data2.size());
  writer.writeEndOfFile();
  const std::string the_file = oss.str();

  // Read records back from a stream over the same bytes.
  std::istringstream iss(the_file);
  PyTorchStreamReader reader(&iss);
  ASSERT_TRUE(reader.hasFile("key1"));
  ASSERT_TRUE(reader.hasFile("key2"));
  ASSERT_FALSE(reader.hasFile("key2000"));

  at::DataPtr data_ptr;
  size_t size = 0;

  // key1: returned contents, then the raw bytes at the reported offset.
  std::tie(data_ptr, size) = reader.getRecord("key1");
  const size_t off1 = reader.getRecordOffset("key1");
  ASSERT_EQ(size, data1.size());
  ASSERT_TRUE(data_ptr.get() != nullptr);
  ASSERT_EQ(std::memcmp(data_ptr.get(), data1.data(), data1.size()), 0);
  // Fail cleanly instead of reading past the buffer if the offset is bogus.
  ASSERT_LE(off1 + data1.size(), the_file.size());
  ASSERT_EQ(std::memcmp(the_file.data() + off1, data1.data(), data1.size()), 0);
  ASSERT_EQ(off1 % kFieldAlignment, 0u);

  // key2: same checks, in the same order.
  std::tie(data_ptr, size) = reader.getRecord("key2");
  const size_t off2 = reader.getRecordOffset("key2");
  ASSERT_EQ(size, data2.size());
  ASSERT_TRUE(data_ptr.get() != nullptr);
  ASSERT_EQ(std::memcmp(data_ptr.get(), data2.data(), data2.size()), 0);
  ASSERT_LE(off2 + data2.size(), the_file.size());
  ASSERT_EQ(std::memcmp(the_file.data() + off2, data2.data(), data2.size()), 0);
  ASSERT_EQ(off2 % kFieldAlignment, 0u);
}

} // namespace
} // namespace serialize
} // namespace caffe2