Format caffe2/serialize (#141850)
Fixes #ISSUE_NUMBER
Pull Request resolved: https://github.com/pytorch/pytorch/pull/141850
Approved by: https://github.com/cpuhrsch
This commit is contained in:
parent 941da90e8a
commit bffaddf9ea
@@ -21,7 +1,8 @@ FileAdapter::RAIIFile::RAIIFile(const std::string& file_name) {
 auto error_msg =
 std::system_category().default_error_condition(old_errno).message();
 #endif
-TORCH_CHECK(false,
+TORCH_CHECK(
+false,
 "open file failed because of errno ",
 old_errno,
 " on fopen: ",
@@ -1,8 +1,8 @@
 #pragma once

+#include <c10/macros/Macros.h>
 #include <fstream>
 #include <memory>
-#include <c10/macros/Macros.h>

 #include "caffe2/serialize/istream_adapter.h"
 #include "caffe2/serialize/read_adapter_interface.h"
@@ -1,7 +1,6 @@
 #pragma once
-#include <cstring>
 #include <caffe2/serialize/read_adapter_interface.h>
+#include <cstring>

 namespace caffe2 {
 namespace serialize {
@@ -17,7 +16,7 @@ class MemoryReadAdapter final : public caffe2::serialize::ReadAdapterInterface {

 size_t read(uint64_t pos, void* buf, size_t n, const char* what = "")
 const override {
-(void) what;
+(void)what;
 memcpy(buf, (int8_t*)(data_) + pos, n);
 return n;
 }
@@ -27,6 +26,5 @@ class MemoryReadAdapter final : public caffe2::serialize::ReadAdapterInterface {
 off_t size_;
 };

-
 } // namespace serialize
 } // namespace caffe2
@@ -1,13 +1,13 @@
-#include <cstdio>
-#include <cstring>
-#include <cerrno>
-#include <istream>
-#include <ostream>
-#include <fstream>
-#include <algorithm>
-#include <sstream>
 #include <sys/stat.h>
 #include <sys/types.h>
+#include <algorithm>
+#include <cerrno>
+#include <cstdio>
+#include <cstring>
+#include <fstream>
+#include <istream>
+#include <ostream>
+#include <sstream>
 #include <thread>

 #include <c10/core/Allocator.h>
@@ -48,25 +48,27 @@ ChunkRecordIterator::~ChunkRecordIterator() {
 mz_zip_reader_extract_iter_free(iter_->impl);
 }

-size_t ChunkRecordIterator::next(void* buf){
+size_t ChunkRecordIterator::next(void* buf) {
 size_t want_size = std::min(chunkSize_, recordSize_ - offset_);
 if (want_size == 0) {
 return 0;
 }
-size_t read_size = mz_zip_reader_extract_iter_read(iter_->impl, buf, want_size);
+size_t read_size =
+mz_zip_reader_extract_iter_read(iter_->impl, buf, want_size);
 TORCH_CHECK(read_size > 0, "Read bytes should be larger than 0");
 offset_ += read_size;
 return read_size;
 }

-size_t istream_read_func(void* pOpaque, mz_uint64 file_ofs, void* pBuf, size_t n) {
+size_t
+istream_read_func(void* pOpaque, mz_uint64 file_ofs, void* pBuf, size_t n) {
 auto self = static_cast<PyTorchStreamReader*>(pOpaque);
 return self->read(file_ofs, static_cast<char*>(pBuf), n);
 }

 static std::string basename(const std::string& name) {
 size_t start = 0;
-for(size_t i = 0; i < name.size(); ++i) {
+for (size_t i = 0; i < name.size(); ++i) {
 if (name[i] == '\\' || name[i] == '/') {
 start = i + 1;
 }
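
As an aside (not part of the patch): the reflowed ChunkRecordIterator::next() above reads at most chunkSize_ bytes of whatever remains in the record and signals completion by returning 0. A minimal standalone sketch of that chunking arithmetic, with illustrative names and an in-memory buffer standing in for the zip record:

#include <algorithm>
#include <cstddef>
#include <cstring>
#include <vector>

// Illustrative stand-in for one record's bytes plus a chunk cursor.
struct ChunkCursor {
  const std::vector<unsigned char>& record;
  size_t chunkSize;
  size_t offset = 0;

  // Mirrors the next() logic: copy min(chunkSize, remaining) bytes, return 0 when done.
  size_t next(void* buf) {
    size_t want_size = std::min(chunkSize, record.size() - offset);
    if (want_size == 0) {
      return 0;
    }
    std::memcpy(buf, record.data() + offset, want_size);
    offset += want_size;
    return want_size;
  }
};

int main() {
  std::vector<unsigned char> record(1000, 7);
  ChunkCursor cursor{record, 256};
  std::vector<unsigned char> buf(256);
  size_t total = 0;
  while (size_t got = cursor.next(buf.data())) {
    total += got;  // reads of 256, 256, 256, then 232 bytes
  }
  return total == record.size() ? 0 : 1;
}
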
@@ -77,7 +79,7 @@ static std::string basename(const std::string& name) {
 }

 size_t end = name.size();
-for(size_t i = end; i > start; --i) {
+for (size_t i = end; i > start; --i) {
 if (name[i - 1] == '.') {
 end = i - 1;
 break;
@@ -92,13 +94,13 @@ static std::string parentdir(const std::string& name) {
 end = name.find_last_of('\\');
 }

 #ifdef WIN32
 if (end != std::string::npos && end > 1 && name[end - 1] == ':') {
 // This is a Windows root directory, so include the slash in
 // the parent directory
 end++;
 }
 #endif

 if (end == std::string::npos) {
 return "";
@@ -157,8 +159,8 @@ void PyTorchStreamReader::init() {
 mz_zip_reader_init(ar_.get(), size, 0);
 valid("reading zip archive");

-// figure out the archive_name (i.e. the zip folder all the other files are in)
-// all lookups to getRecord will be prefixed by this folder
+// figure out the archive_name (i.e. the zip folder all the other files are
+// in) all lookups to getRecord will be prefixed by this folder
 mz_uint n = mz_zip_reader_get_num_files(ar_.get());
 if (n == 0) {
 CAFFE_THROW("archive does not contain any files");
@@ -201,15 +203,15 @@ void PyTorchStreamReader::init() {
 TORCH_CHECK(hasRecord("version"))
 std::tie(version_ptr, version_size) = getRecord("version");
 }
-std::string version(static_cast<const char*>(version_ptr.get()), version_size);
+std::string version(
+static_cast<const char*>(version_ptr.get()), version_size);
 try {
 version_ = std::stoull(version);
 } catch (const std::invalid_argument& e) {
-CAFFE_THROW("Couldn't parse the version ",
-version,
-" as Long Long.");
+CAFFE_THROW("Couldn't parse the version ", version, " as Long Long.");
 }
-if (version_ < static_cast<decltype(version_)>(kMinSupportedFileFormatVersion)) {
+if (version_ <
+static_cast<decltype(version_)>(kMinSupportedFileFormatVersion)) {
 CAFFE_THROW(
 "Attempted to read a PyTorch file with version ",
 std::to_string(version_),
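
For context only, a small self-contained sketch (hypothetical helper, not PyTorch API) of the std::stoull parsing and std::invalid_argument handling that the reflowed CAFFE_THROW call above wraps; stoull tolerates the trailing newline that the stored "version" record carries:

#include <cstdint>
#include <iostream>
#include <stdexcept>
#include <string>

// Parse a "version" record (e.g. "6\n"): stoull stops at the newline
// and throws std::invalid_argument on non-numeric input.
uint64_t parseVersion(const std::string& text) {
  try {
    return std::stoull(text);
  } catch (const std::invalid_argument&) {
    throw std::runtime_error("Couldn't parse the version " + text + " as Long Long.");
  }
}

int main() {
  std::cout << parseVersion("6\n") << '\n';  // prints 6
  try {
    parseVersion("not-a-number");
  } catch (const std::runtime_error& e) {
    std::cout << e.what() << '\n';
  }
  return 0;
}
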
@@ -219,7 +221,8 @@ void PyTorchStreamReader::init() {
 " with latest version of PyTorch to mitigate this issue.");
 }

-if (version_ > static_cast<decltype(version_)>(kMaxSupportedFileFormatVersion)) {
+if (version_ >
+static_cast<decltype(version_)>(kMaxSupportedFileFormatVersion)) {
 CAFFE_THROW(
 "Attempted to read a PyTorch file with version ",
 version_,
@@ -277,12 +280,13 @@ size_t getPadding(
 padding_buf[3] = (uint8_t)(padding_size >> 8);
 return padding_size_plus_fbxx;
 }
-}
+} // namespace detail

 bool PyTorchStreamReader::hasRecord(const std::string& name) {
 std::lock_guard<std::mutex> guard(reader_lock_);

-if ((!load_debug_symbol_) && c10::ends_with(std::string_view(name), kDebugPklSuffix)) {
+if ((!load_debug_symbol_) &&
+c10::ends_with(std::string_view(name), kDebugPklSuffix)) {
 return false;
 }
 std::string ss = archive_name_plus_slash_ + name;
@@ -307,7 +311,8 @@ std::vector<std::string> PyTorchStreamReader::getAllRecords() {
 // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
 char buf[MZ_ZIP_MAX_ARCHIVE_FILENAME_SIZE];
 for (size_t i = 0; i < num_files; i++) {
-mz_zip_reader_get_filename(ar_.get(), i, buf, MZ_ZIP_MAX_ARCHIVE_FILENAME_SIZE);
+mz_zip_reader_get_filename(
+ar_.get(), i, buf, MZ_ZIP_MAX_ARCHIVE_FILENAME_SIZE);
 if (strncmp(
 buf,
 archive_name_plus_slash_.data(),
@@ -319,7 +324,9 @@ std::vector<std::string> PyTorchStreamReader::getAllRecords() {
 buf);
 }
 if ((load_debug_symbol_) ||
-(!c10::ends_with(std::string_view(buf + archive_name_plus_slash_.size()),kDebugPklSuffix))) {
+(!c10::ends_with(
+std::string_view(buf + archive_name_plus_slash_.size()),
+kDebugPklSuffix))) {
 // NOLINTNEXTLINE(modernize-use-emplace)
 out.push_back(buf + archive_name_plus_slash_.size());
 }
@@ -340,7 +347,8 @@ size_t PyTorchStreamReader::getRecordID(const std::string& name) {
 }

 // return dataptr, size
-std::tuple<at::DataPtr, size_t> PyTorchStreamReader::getRecord(const std::string& name) {
+std::tuple<at::DataPtr, size_t> PyTorchStreamReader::getRecord(
+const std::string& name) {
 std::lock_guard<std::mutex> guard(reader_lock_);
 if ((!load_debug_symbol_) && c10::ends_with(name, kDebugPklSuffix)) {
 at::DataPtr retval;
@@ -351,45 +359,57 @@ std::tuple<at::DataPtr, size_t> PyTorchStreamReader::getRecord(const std::string
 mz_zip_reader_file_stat(ar_.get(), key, &stat);
 valid("retrieving file meta-data for ", name.c_str());
 at::DataPtr retval = c10::GetCPUAllocator()->allocate(stat.m_uncomp_size);
-mz_zip_reader_extract_to_mem(ar_.get(), key, retval.get(), stat.m_uncomp_size, 0);
+mz_zip_reader_extract_to_mem(
+ar_.get(), key, retval.get(), stat.m_uncomp_size, 0);
 valid("reading file ", name.c_str());

 return std::make_tuple(std::move(retval), stat.m_uncomp_size);
 }

-size_t
-PyTorchStreamReader::getRecordMultiReaders(const std::string& name,
+size_t PyTorchStreamReader::getRecordMultiReaders(
+const std::string& name,
 std::vector<std::shared_ptr<ReadAdapterInterface>>& additionalReaders,
-void *dst, size_t n){
-size_t nthread = additionalReaders.size()+1;
+void* dst,
+size_t n) {
+size_t nthread = additionalReaders.size() + 1;
 size_t recordOff = getRecordOffset(name);
 std::vector<std::thread> loaderThreads;
-size_t perThreadSize = (n+nthread-1)/nthread;
+size_t perThreadSize = (n + nthread - 1) / nthread;
 std::vector<size_t> readSizes(nthread, 0);
 std::lock_guard<std::mutex> guard(reader_lock_);
-for(size_t i = 0; i < nthread ; i++){
-loaderThreads.emplace_back([this, name, i, n, recordOff, perThreadSize, dst, &additionalReaders, &readSizes]{
-size_t startPos = i*perThreadSize;
-size_t endPos = std::min((i+1)*perThreadSize,n);
-if (startPos < endPos){
+for (size_t i = 0; i < nthread; i++) {
+loaderThreads.emplace_back([this,
+name,
+i,
+n,
+recordOff,
+perThreadSize,
+dst,
+&additionalReaders,
+&readSizes] {
+size_t startPos = i * perThreadSize;
+size_t endPos = std::min((i + 1) * perThreadSize, n);
+if (startPos < endPos) {
 size_t threadReadSize = endPos - startPos;
 size_t size = 0;
-if (i==0){
-size = read(recordOff+startPos, (char *)dst+startPos, threadReadSize);
-}else{
-auto reader = additionalReaders[i-1];
-size = reader->read(recordOff+startPos, (char *)dst+startPos, threadReadSize);
+if (i == 0) {
+size =
+read(recordOff + startPos, (char*)dst + startPos, threadReadSize);
+} else {
+auto reader = additionalReaders[i - 1];
+size = reader->read(
+recordOff + startPos, (char*)dst + startPos, threadReadSize);
 }
 readSizes[i] = size;
-LOG(INFO) << "Thread " << i << " read [" << startPos << "-" << endPos << "] "
-<< "from " << name << " of size " << n;
+LOG(INFO) << "Thread " << i << " read [" << startPos << "-" << endPos
+<< "] "
+<< "from " << name << " of size " << n;
 TORCH_CHECK(
 threadReadSize == size,
 "record size ",
 threadReadSize,
 " mismatch with read size ",
 size);
 }
 });
 }
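
The hunk above only reflows getRecordMultiReaders(), but the splitting scheme is easy to lose in the diff noise: the n-byte record is divided into nthread = additionalReaders.size() + 1 contiguous ranges of ceil(n / nthread) bytes, where thread 0 reads through the default reader and thread i > 0 through additionalReaders[i - 1]. A standalone sketch of that partition arithmetic (illustrative names, no PyTorch types):

#include <algorithm>
#include <cstddef>
#include <iostream>
#include <utility>
#include <vector>

// Compute the [start, end) byte range each reader thread is responsible for,
// using the same ceiling-division split as getRecordMultiReaders.
std::vector<std::pair<size_t, size_t>> splitRecord(size_t n, size_t nAdditionalReaders) {
  size_t nthread = nAdditionalReaders + 1;             // default reader + extras
  size_t perThreadSize = (n + nthread - 1) / nthread;  // ceil(n / nthread)
  std::vector<std::pair<size_t, size_t>> ranges;
  for (size_t i = 0; i < nthread; i++) {
    size_t startPos = i * perThreadSize;
    size_t endPos = std::min((i + 1) * perThreadSize, n);
    if (startPos < endPos) {
      ranges.emplace_back(startPos, endPos);
    }
  }
  return ranges;
}

int main() {
  // 10 bytes, 2 additional readers -> [0,4) [4,8) [8,10), matching the header comment.
  for (auto [s, e] : splitRecord(10, 2)) {
    std::cout << "[" << s << "," << e << ") ";
  }
  std::cout << "\n";
  return 0;
}
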
@@ -400,7 +420,7 @@ PyTorchStreamReader::getRecordMultiReaders(const std::string& name,
 loaderThreads.clear();

 size_t total_read_n = 0;
-for (auto& r : readSizes){
+for (auto& r : readSizes) {
 total_read_n += r;
 }
@@ -415,10 +435,10 @@ PyTorchStreamReader::getRecordMultiReaders(const std::string& name,
 }

 // read record with multi clients
-std::tuple<at::DataPtr, size_t>
-PyTorchStreamReader::getRecord(const std::string& name,
+std::tuple<at::DataPtr, size_t> PyTorchStreamReader::getRecord(
+const std::string& name,
 std::vector<std::shared_ptr<ReadAdapterInterface>>& additionalReaders) {
-if(additionalReaders.empty()){
+if (additionalReaders.empty()) {
 // No additional readers or record too small, use single threaded version
 return getRecord(name);
 }
@@ -432,7 +452,7 @@ PyTorchStreamReader::getRecord(const std::string& name,
 mz_zip_reader_file_stat(ar_.get(), key, &stat);
 auto n = stat.m_uncomp_size;
 valid("retrieving file meta-data for ", name.c_str());
-if(n < additional_reader_size_threshold_){
+if (n < additional_reader_size_threshold_) {
 // Reader size too small, use single threaded version
 return getRecord(name);
 }
@@ -466,17 +486,20 @@ PyTorchStreamReader::getRecord(const std::string& name, void* dst, size_t n) {
 return stat.m_uncomp_size;
 }

-// inplace memory writing, in-tensor multi-threads, can be used for large tensor.
-size_t
-PyTorchStreamReader::getRecord(const std::string& name, void* dst, size_t n,
-std::vector<std::shared_ptr<ReadAdapterInterface>>& additionalReaders) {
-if(additionalReaders.empty()){
+// inplace memory writing, in-tensor multi-threads, can be used for large
+// tensor.
+size_t PyTorchStreamReader::getRecord(
+const std::string& name,
+void* dst,
+size_t n,
+std::vector<std::shared_ptr<ReadAdapterInterface>>& additionalReaders) {
+if (additionalReaders.empty()) {
 // No additional readers, use single threaded version
 return getRecord(name, dst, n);
 }

-if ((!load_debug_symbol_) && c10::ends_with(std::string_view(name), kDebugPklSuffix)) {
+if ((!load_debug_symbol_) &&
+c10::ends_with(std::string_view(name), kDebugPklSuffix)) {
 return 0;
 }
 size_t key = getRecordID(name);
@@ -490,7 +513,7 @@ PyTorchStreamReader::getRecord(const std::string& name, void* dst, size_t n,
 n);
 valid("retrieving file meta-data for ", name.c_str());

-if(n < additional_reader_size_threshold_){
+if (n < additional_reader_size_threshold_) {
 // Reader size too small, use single threaded version
 return getRecord(name, dst, n);
 }
@@ -577,7 +600,8 @@ size_t PyTorchStreamReader::getRecordOffset(const std::string& name) {
 "reading file header");
 size_t filename_len = read_le_16(local_header + MZ_ZIP_LDH_FILENAME_LEN_OFS);
 size_t extra_len = read_le_16(local_header + MZ_ZIP_LDH_EXTRA_LEN_OFS);
-return stat.m_local_header_ofs + MZ_ZIP_LOCAL_DIR_HEADER_SIZE + filename_len + extra_len;
+return stat.m_local_header_ofs + MZ_ZIP_LOCAL_DIR_HEADER_SIZE + filename_len +
+extra_len;
 }

 size_t PyTorchStreamReader::getRecordSize(const std::string& name) {
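
getRecordOffset() above locates a record's payload by skipping the ZIP local file header: the fixed 30-byte header plus the variable-length file name and extra field, whose lengths are stored as little-endian 16-bit values at offsets 26 and 28 of the header. A hedged sketch of that arithmetic; the constants below are illustrative stand-ins for the MZ_ZIP_* macros used in the real code:

#include <cstddef>
#include <cstdint>
#include <iostream>

// Illustrative constants matching the ZIP local file header layout
// (30-byte fixed header, name length at offset 26, extra length at offset 28).
constexpr size_t kLocalDirHeaderSize = 30;
constexpr size_t kFilenameLenOfs = 26;
constexpr size_t kExtraLenOfs = 28;

// Same role as read_le_16 in the reader: decode a little-endian uint16.
static uint16_t read_le_16(const uint8_t* buf) {
  return static_cast<uint16_t>(buf[0]) | (static_cast<uint16_t>(buf[1]) << 8);
}

// Payload offset = local header offset + fixed header size
// + variable-length file name + variable-length extra field.
size_t recordDataOffset(size_t local_header_ofs, const uint8_t* local_header) {
  size_t filename_len = read_le_16(local_header + kFilenameLenOfs);
  size_t extra_len = read_le_16(local_header + kExtraLenOfs);
  return local_header_ofs + kLocalDirHeaderSize + filename_len + extra_len;
}

int main() {
  uint8_t header[kLocalDirHeaderSize] = {};
  header[kFilenameLenOfs] = 12;  // e.g. a 12-byte record name
  header[kExtraLenOfs] = 20;     // e.g. a 20-byte alignment padding field
  std::cout << recordDataOffset(100, header) << "\n";  // 100 + 30 + 12 + 20 = 162
  return 0;
}
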
@@ -620,14 +644,16 @@ size_t ostream_write_func(
 return ret;
 }

-PyTorchStreamWriter::PyTorchStreamWriter(const std::string& file_name, bool compute_crc32)
-: archive_name_(basename(file_name)),
-compute_crc32_(compute_crc32) {
+PyTorchStreamWriter::PyTorchStreamWriter(
+const std::string& file_name,
+bool compute_crc32)
+: archive_name_(basename(file_name)), compute_crc32_(compute_crc32) {
 setup(file_name);
 }

 PyTorchStreamWriter::PyTorchStreamWriter(
-const std::function<size_t(const void*, size_t)> writer_func, bool compute_crc32)
+const std::function<size_t(const void*, size_t)> writer_func,
+bool compute_crc32)
 : archive_name_("archive"),
 writer_func_(writer_func),
 compute_crc32_(compute_crc32) {
@@ -649,10 +675,12 @@ void PyTorchStreamWriter::setup(const string& file_name) {
 valid("opening archive ", file_name.c_str());

 const std::string dir_name = parentdir(file_name);
-if(!dir_name.empty()) {
+if (!dir_name.empty()) {
 struct stat st;
-bool dir_exists = (stat(dir_name.c_str(), &st) == 0 && (st.st_mode & S_IFDIR));
-TORCH_CHECK(dir_exists, "Parent directory ", dir_name, " does not exist.");
+bool dir_exists =
+(stat(dir_name.c_str(), &st) == 0 && (st.st_mode & S_IFDIR));
+TORCH_CHECK(
+dir_exists, "Parent directory ", dir_name, " does not exist.");
 }
 TORCH_CHECK(file_stream_, "File ", file_name, " cannot be opened.");
 writer_func_ = [this](const void* buf, size_t nbytes) -> size_t {
@@ -728,17 +756,20 @@ void PyTorchStreamWriter::writeEndOfFile() {
 // destructor would would result in `std::terminate()`
 // See https://github.com/pytorch/pytorch/issues/87997/
 struct Finalizer {
-Finalizer(bool& var): var_(var) {}
+Finalizer(bool& var) : var_(var) {}
 ~Finalizer() {
 var_ = true;
 }

 private:
 bool& var_;
 } f(finalized_);

 auto allRecords = getAllWrittenRecords();
-// If no ".data/version" or "version" record in the output model, rewrites version info
-if(allRecords.find(".data/version") == allRecords.end() && allRecords.find("version") == allRecords.end()) {
+// If no ".data/version" or "version" record in the output model, rewrites
+// version info
+if (allRecords.find(".data/version") == allRecords.end() &&
+allRecords.find("version") == allRecords.end()) {
 std::string version = std::to_string(version_);
 version.push_back('\n');
 if (version_ >= 0x6L) {
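
The Finalizer struct reflowed in this hunk is an RAII guard: its destructor flips finalized_ on every exit path from writeEndOfFile(), including exceptional ones, so a later destructor call does not try to finish the archive again and throw (which would end in std::terminate(), per the linked issue). A minimal sketch of the same pattern with made-up names:

#include <iostream>
#include <stdexcept>

bool finalized = false;

void writeEndOfFileLike() {
  // Guard that marks the state finalized even if an exception escapes below,
  // mirroring the local Finalizer struct in writeEndOfFile().
  struct Finalizer {
    explicit Finalizer(bool& var) : var_(var) {}
    ~Finalizer() {
      var_ = true;
    }

   private:
    bool& var_;
  } f(finalized);

  throw std::runtime_error("simulated write failure");
}

int main() {
  try {
    writeEndOfFileLike();
  } catch (const std::runtime_error& e) {
    std::cout << e.what() << "\n";
  }
  // finalized is true despite the exception, so a destructor checking it can bail out safely.
  std::cout << std::boolalpha << finalized << "\n";  // true
  return 0;
}
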
@@ -749,7 +780,7 @@ void PyTorchStreamWriter::writeEndOfFile() {
 }

 // If no "byteorder" record in the output model, rewrites byteorder info
-if(allRecords.find("byteorder") == allRecords.end()) {
+if (allRecords.find("byteorder") == allRecords.end()) {
 #if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__
 std::string byteorder = "little";
 #elif __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
@@ -808,9 +839,8 @@ void PyTorchStreamWriter::writeSerializationId() {
 }
 std::ostringstream serialization_id_oss;
 serialization_id_oss << std::setfill('0') << std::setw(20)
-<< combined_record_name_hash
-<< std::setfill('0') << std::setw(20)
-<< combined_uncomp_crc32_;
+<< combined_record_name_hash << std::setfill('0')
+<< std::setw(20) << combined_uncomp_crc32_;
 serialization_id_ = serialization_id_oss.str();
 writeRecord(
 kSerializationIdRecordName,
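
The stream-insertion chain reformatted above builds serialization_id_ as two zero-padded, 20-digit decimal fields: a hash combined from the written record names followed by the combined uncompressed CRC32. A small sketch of that formatting with stand-in values (the helper name is illustrative):

#include <cstdint>
#include <iomanip>
#include <iostream>
#include <sstream>
#include <string>

// Format two integers as back-to-back 20-digit zero-padded fields,
// the same std::setfill/std::setw sequence used in writeSerializationId().
std::string makeSerializationId(uint64_t record_name_hash, uint64_t combined_crc32) {
  std::ostringstream oss;
  oss << std::setfill('0') << std::setw(20) << record_name_hash << std::setfill('0')
      << std::setw(20) << combined_crc32;
  return oss.str();
}

int main() {
  std::string id = makeSerializationId(1234567890u, 42u);
  std::cout << id << "\n";         // 10 zeros + "1234567890", then 18 zeros + "42"
  std::cout << id.size() << "\n";  // 40
  return 0;
}
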
@@ -16,7 +16,6 @@
 #include "caffe2/serialize/read_adapter_interface.h"
 #include "caffe2/serialize/versions.h"

-
 extern "C" {
 typedef struct mz_zip_archive mz_zip_archive;
 }
@@ -94,7 +93,8 @@ typedef struct mz_zip_archive mz_zip_archive;
 namespace caffe2 {
 namespace serialize {

-static constexpr const char* kSerializationIdRecordName = ".data/serialization_id";
+static constexpr const char* kSerializationIdRecordName =
+".data/serialization_id";

 struct MzZipReaderIterWrapper;
@@ -102,12 +102,15 @@ class TORCH_API ChunkRecordIterator {
 public:
 ~ChunkRecordIterator();

-// Read at most `chunkSize` into `buf`. Return the number of actual bytes read.
+// Read at most `chunkSize` into `buf`. Return the number of actual bytes
+// read.
 size_t next(void* buf);
-size_t recordSize() const { return recordSize_; }
+size_t recordSize() const {
+return recordSize_;
+}

 private:
 ChunkRecordIterator(
 size_t recordSize,
 size_t chunkSize,
 std::unique_ptr<MzZipReaderIterWrapper> iter);
@@ -129,35 +132,44 @@ class TORCH_API PyTorchStreamReader final {
 // return dataptr, size
 std::tuple<at::DataPtr, size_t> getRecord(const std::string& name);
 // multi-thread getRecord
-std::tuple<at::DataPtr, size_t> getRecord(const std::string& name, std::vector<std::shared_ptr<ReadAdapterInterface>>& additionalReaders);
+std::tuple<at::DataPtr, size_t> getRecord(
+const std::string& name,
+std::vector<std::shared_ptr<ReadAdapterInterface>>& additionalReaders);
 // inplace memory writing
 size_t getRecord(const std::string& name, void* dst, size_t n);
 // inplace memory writing, multi-threads.
-// When additionalReaders is empty, the default behavior is call getRecord(name, dst, n) with default reader
-// This approach can be used for reading large tensors.
-size_t getRecord(const std::string& name, void* dst, size_t n,
-std::vector<std::shared_ptr<ReadAdapterInterface>>& additionalReaders);
+// When additionalReaders is empty, the default behavior is call
+// getRecord(name, dst, n) with default reader This approach can be used for
+// reading large tensors.
+size_t getRecord(
+const std::string& name,
+void* dst,
+size_t n,
+std::vector<std::shared_ptr<ReadAdapterInterface>>& additionalReaders);
 size_t getRecord(
 const std::string& name,
 void* dst,
 size_t n,
 size_t chunk_size,
 void* buf,
-const std::function<void(void*, const void*, size_t)>& memcpy_func = nullptr);
+const std::function<void(void*, const void*, size_t)>& memcpy_func =
+nullptr);

 // Concurrent reading records with multiple readers.
-// additionalReaders are additional clients to access the underlying record at different offsets
-// and write to different trunks of buffers.
-// If the overall size of the tensor is 10, and size of additionalReader is 2.
-// The default thread will read [0,4), the additional reader will read [4,8).
-// The default reader will read [8,10).
-// The default reader will write to buffer[0,4), the additional reader will write to buffer[4,8),
-// the additional reader will write to buffer[8,10).
-// When additionalReaders is empty, the default behavior is call getRecord(name) with default reader
-// This approach can be used for reading large tensors.
-size_t getRecordMultiReaders(const std::string& name,
-std::vector<std::shared_ptr<ReadAdapterInterface>>& additionalReaders,
-void *dst, size_t n);
+// additionalReaders are additional clients to access the underlying record at
+// different offsets and write to different trunks of buffers. If the overall
+// size of the tensor is 10, and size of additionalReader is 2. The default
+// thread will read [0,4), the additional reader will read [4,8). The default
+// reader will read [8,10). The default reader will write to buffer[0,4), the
+// additional reader will write to buffer[4,8), the additional reader will
+// write to buffer[8,10). When additionalReaders is empty, the default
+// behavior is call getRecord(name) with default reader This approach can be
+// used for reading large tensors.
+size_t getRecordMultiReaders(
+const std::string& name,
+std::vector<std::shared_ptr<ReadAdapterInterface>>& additionalReaders,
+void* dst,
+size_t n);

 size_t getRecordSize(const std::string& name);
@@ -181,9 +193,10 @@ class TORCH_API PyTorchStreamReader final {
 void setShouldLoadDebugSymbol(bool should_load_debug_symbol) {
 load_debug_symbol_ = should_load_debug_symbol;
 }
-void setAdditionalReaderSizeThreshold(const size_t& size){
+void setAdditionalReaderSizeThreshold(const size_t& size) {
 additional_reader_size_threshold_ = size;
 }

 private:
 void init();
 size_t read(uint64_t pos, char* buf, size_t n);
@@ -205,9 +218,12 @@ class TORCH_API PyTorchStreamReader final {

 class TORCH_API PyTorchStreamWriter final {
 public:
-explicit PyTorchStreamWriter(const std::string& archive_name, bool compute_crc32 = true);
 explicit PyTorchStreamWriter(
-const std::function<size_t(const void*, size_t)> writer_func, bool compute_crc32 = true);
+const std::string& archive_name,
+bool compute_crc32 = true);
+explicit PyTorchStreamWriter(
+const std::function<size_t(const void*, size_t)> writer_func,
+bool compute_crc32 = true);

 void setMinVersion(const uint64_t version);
@@ -5,9 +5,9 @@

 #include <gtest/gtest.h>

-#include "caffe2/serialize/inline_container.h"
 #include <c10/util/Logging.h>
 #include "c10/util/irange.h"
+#include "caffe2/serialize/inline_container.h"

 namespace caffe2 {
 namespace serialize {
@@ -77,9 +77,12 @@ TEST(PyTorchStreamWriterAndReader, SaveAndLoad) {
 ASSERT_EQ(memcmp(dst.data(), data1.data(), size), 0);
 // chunked getRecord() test
 ret = reader.getRecord(
-"key1", dst.data(), size, 3, buf.data(), [](void* dst, const void* src, size_t n) {
-memcpy(dst, src, n);
-});
+"key1",
+dst.data(),
+size,
+3,
+buf.data(),
+[](void* dst, const void* src, size_t n) { memcpy(dst, src, n); });
 ASSERT_EQ(ret, size);
 ASSERT_EQ(memcmp(dst.data(), data1.data(), size), 0);
@@ -97,9 +100,12 @@ TEST(PyTorchStreamWriterAndReader, SaveAndLoad) {
 ASSERT_EQ(memcmp(dst.data(), data2.data(), size), 0);
 // chunked getRecord() test
 ret = reader.getRecord(
-"key2", dst.data(), size, 3, buf.data(), [](void* dst, const void* src, size_t n) {
-memcpy(dst, src, n);
-});
+"key2",
+dst.data(),
+size,
+3,
+buf.data(),
+[](void* dst, const void* src, size_t n) { memcpy(dst, src, n); });
 ASSERT_EQ(ret, size);
 ASSERT_EQ(memcmp(dst.data(), data2.data(), size), 0);
 // clean up
@@ -107,7 +113,6 @@ TEST(PyTorchStreamWriterAndReader, SaveAndLoad) {
 }

 TEST(PyTorchStreamWriterAndReader, LoadWithMultiThreads) {
-
 std::ostringstream oss;
 // write records through writers
 PyTorchStreamWriter writer([&](const void* b, size_t n) -> size_t {
@@ -156,7 +161,7 @@ TEST(PyTorchStreamWriterAndReader, LoadWithMultiThreads) {

 // Test getRecord(name, additional_readers)
 std::vector<std::shared_ptr<ReadAdapterInterface>> additionalReader;
-for(int i=0; i<10; ++i){
+for (int i = 0; i < 10; ++i) {
 // Test various sized additional readers.
 std::tie(data_ptr, ret) = reader.getRecord("key1", additionalReader);
 ASSERT_EQ(ret, size1);
@@ -170,7 +175,7 @@ TEST(PyTorchStreamWriterAndReader, LoadWithMultiThreads) {
 // Inplace multi-threading getRecord(name, dst, n, additional_readers) test
 additionalReader.clear();
 std::vector<uint8_t> dst1(size1), dst2(size2);
-for(int i=0; i<10; ++i){
+for (int i = 0; i < 10; ++i) {
 // Test various sizes of read threads
 additionalReader.push_back(std::make_unique<IStreamAdapter>(&iss));
@@ -324,7 +329,7 @@ TEST(PytorchStreamWriterAndReader, ValidSerializationId) {
 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init,cppcoreguidelines-avoid-magic-numbers)
 std::array<char, 127> data1;

-for (auto i: c10::irange(data1.size())) {
+for (auto i : c10::irange(data1.size())) {
 data1[i] = data1.size() - i;
 }
 writer.writeRecord("key1.debug_pkl", data1.data(), data1.size());
@@ -361,7 +366,10 @@ TEST(PytorchStreamWriterAndReader, SkipDuplicateSerializationIdRecords) {
 });

 std::string dup_serialization_id = "dup-serialization-id";
-writer.writeRecord(kSerializationIdRecordName, dup_serialization_id.c_str(), dup_serialization_id.size());
+writer.writeRecord(
+kSerializationIdRecordName,
+dup_serialization_id.c_str(),
+dup_serialization_id.size());

 const std::unordered_set<std::string>& written_records =
 writer.getAllWrittenRecords();
@@ -410,13 +418,12 @@ TEST(PytorchStreamWriterAndReader, LogAPIUsageMetadata) {
 std::map<std::string, std::map<std::string, std::string>> expected_logs = {
 {"pytorch.stream.writer.metadata",
 {{"serialization_id", writer.serializationId()},
 {"file_name", "archive"},
 {"file_size", str(oss.str().length())}}},
 {"pytorch.stream.reader.metadata",
 {{"serialization_id", writer.serializationId()},
 {"file_name", "archive"},
-{"file_size", str(iss.str().length())}}}
-};
+{"file_size", str(iss.str().length())}}}};
 ASSERT_EQ(expected_logs, logs);

 // reset logger
@@ -433,7 +440,8 @@ INSTANTIATE_TEST_SUITE_P(

 TEST_P(ChunkRecordIteratorTest, ChunkRead) {
 auto chunkSize = GetParam();
-std::string zipFileName = "output_chunk_" + std::to_string(chunkSize) + ".zip";
+std::string zipFileName =
+"output_chunk_" + std::to_string(chunkSize) + ".zip";
 const char* fileName = zipFileName.c_str();
 const std::string recordName = "key1";
 const size_t tensorDataSizeInBytes = 1000;