#include #include #include #include #include "caffe2/serialize/inline_container.h" namespace caffe2 { namespace serialize { namespace { TEST(PyTorchStreamWriterAndReader, SaveAndLoad) { int64_t kFieldAlignment = 64L; std::ostringstream oss; // write records through writers PyTorchStreamWriter writer(&oss); std::array data1; for (int i = 0; i < data1.size(); ++i) { data1[i] = data1.size() - i; } writer.writeRecord("key1", data1.data(), data1.size()); std::array data2; for (int i = 0; i < data2.size(); ++i) { data2[i] = data2.size() - i; } writer.writeRecord("key2", data2.data(), data2.size()); writer.writeEndOfFile(); std::string the_file = oss.str(); std::ofstream foo("output.zip"); foo.write(the_file.c_str(), the_file.size()); foo.close(); std::istringstream iss(the_file); // read records through readers PyTorchStreamReader reader(&iss); ASSERT_TRUE(reader.hasFile("key1")); ASSERT_TRUE(reader.hasFile("key2")); ASSERT_FALSE(reader.hasFile("key2000")); at::DataPtr data_ptr; int64_t size; std::tie(data_ptr, size) = reader.getRecord("key1"); size_t off1 = reader.getRecordOffset("key1"); ASSERT_EQ(size, data1.size()); ASSERT_EQ(memcmp(data_ptr.get(), data1.data(), data1.size()), 0); ASSERT_EQ(memcmp(the_file.c_str() + off1, data1.data(), data1.size()), 0); ASSERT_EQ(off1 % kFieldAlignment, 0); std::tie(data_ptr, size) = reader.getRecord("key2"); size_t off2 = reader.getRecordOffset("key2"); ASSERT_EQ(off2 % kFieldAlignment, 0); ASSERT_EQ(size, data2.size()); ASSERT_EQ(memcmp(data_ptr.get(), data2.data(), data2.size()), 0); ASSERT_EQ(memcmp(the_file.c_str() + off2, data2.data(), data2.size()), 0); } } // namespace } // namespace serialize } // namespace caffe2