ChunkDataset checkpoint support (#21889)

Summary:
When dealing with large scale dataset, it is handy if we can save the dataset status and resume later. Especially in cases where some unexpected crash happens, user don't need to start over the whole dataset from begining. Instead, they can reload it from the last checkpoint.

This change adds support for checkpoint save/load logic in ChunkDataset.

On ChunkDataset construction, user can specify a file name from which to load the checkpoint. If it is empty, default to start from fresh; otherwise the ChunkDataset will 'fast forward' the chunk sampler to the corresponding checkpoint.

The user can also call ChunkDataset::save() to serialize current status to a file, which can be used later.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/21889

Differential Revision: D16024582

Pulled By: ailzhang

fbshipit-source-id: 1862ab5116f94c9d29da174ce04a91041d06cad5
This commit is contained in:
xzhu1900 2019-06-26 22:51:08 -07:00 committed by Facebook Github Bot
parent 30d890c672
commit f39b6624ba
3 changed files with 264 additions and 9 deletions

View File

@ -8,6 +8,7 @@
#include <test/cpp/api/support.h>
#include <c10/util/ArrayRef.h>
#include <c10/util/tempfile.h>
#include <algorithm>
#include <chrono>
@ -98,7 +99,9 @@ TEST(DataTest, ChunkDataSetWithInvalidInitParameter) {
samplers::SequentialSampler sampler(0);
auto initialization_function =
[&](size_t preloader_count, size_t batch_size, size_t cache_size) {
[&](size_t preloader_count,
size_t batch_size,
size_t cache_size) {
datasets::SharedBatchDataset<datasets::ChunkDataset<
DummyChunkDataReader,
samplers::SequentialSampler,
@ -111,7 +114,9 @@ TEST(DataTest, ChunkDataSetWithInvalidInitParameter) {
sampler,
sampler,
datasets::ChunkDatasetOptions(
preloader_count, batch_size, cache_size));
preloader_count,
batch_size,
cache_size));
};
ASSERT_THROWS_WITH(
@ -1465,6 +1470,8 @@ TEST(DataLoaderTest, StatefulDatasetWithNoWorkers) {
void reset() override {
counter = 0;
}
void save(torch::serialize::OutputArchive& archive) const override{};
void load(torch::serialize::InputArchive& archive) override {}
int counter = 0;
};
@ -1501,6 +1508,8 @@ TEST(DataLoaderTest, StatefulDatasetWithManyWorkers) {
void reset() override {
counter = 0;
}
void save(torch::serialize::OutputArchive& archive) const override{};
void load(torch::serialize::InputArchive& archive) override {}
int counter = 0;
std::mutex mutex;
};
@ -1538,6 +1547,8 @@ TEST(DataLoaderTest, StatefulDatasetWithMap) {
void reset() override {
counter = 0;
}
void save(torch::serialize::OutputArchive& archive) const override{};
void load(torch::serialize::InputArchive& archive) override {}
int counter = 0;
};
@ -1585,6 +1596,8 @@ TEST(DataLoaderTest, StatefulDatasetWithCollate) {
void reset() override {
counter = 0;
}
void save(torch::serialize::OutputArchive& archive) const override{};
void load(torch::serialize::InputArchive& archive) override {}
int counter = 0;
};
@ -1880,4 +1893,203 @@ TEST(DataLoaderTest, ChunkDatasetDoesNotHang) {
// to fill the batch buffer but it is not draining. Still we need to exit
// cleanly.
auto iterator = data_loader->begin();
}
// Test ChunkDataset save function.
// Note [save/load ChunkDataset as ChunkSampler]:
// The chunk sampler inside ChunkDataset is used in a separate thread pool other
// than the main thread. Thus it is very hard to accurately estimate its status
// when ChunkDataset::save/ChunkDataset::load is called. For the pure purpose of
// testing, we utilize the implementation fact that the file format for sampler
// serialization is the same as ChunkDataset serialization, and manually control
// the chunk sampler by calling the sampler's save/load method for value
// validation. This is only for testing the specific save/load functionality. In
// real user case, the user should still use matching ChunkDataset::save and
// ChunkDataset::load method.
TEST(DataLoaderTest, ChunkDatasetSave) {
const size_t chunk_count_ = 6;
const size_t chunk_size = 10;
struct DummyTestChunkDataReader : datasets::ChunkDataReader<int> {
public:
using BatchType = datasets::ChunkDataReader<int>::ChunkType;
BatchType read_chunk(size_t chunk_index) override {
return batch_data_;
}
size_t chunk_count() override {
return chunk_count_;
};
void reset() override{};
BatchType batch_data_ = BatchType(chunk_size, 0);
};
const size_t prefetch_count = 1;
const size_t batch_size = chunk_size;
const size_t dataloader_worker_count = 0;
samplers::SequentialSampler sampler(0);
const int epoch_count = 2;
DummyTestChunkDataReader data_reader;
// tested save_intervals
const size_t save_intervals[] = {1, 2};
using datasets::ChunkDatasetOptions;
for (auto save_interval : save_intervals) {
auto tempfile = c10::make_tempfile();
datasets::SharedBatchDataset<datasets::ChunkDataset<
DummyTestChunkDataReader,
samplers::SequentialSampler,
samplers::SequentialSampler>>
dataset = datasets::make_shared_dataset<datasets::ChunkDataset<
DummyTestChunkDataReader,
samplers::SequentialSampler,
samplers::SequentialSampler>>(
data_reader,
sampler,
sampler,
ChunkDatasetOptions(
prefetch_count, batch_size, chunk_size /*cache size*/));
auto data_loader = torch::data::make_data_loader(
dataset,
DataLoaderOptions(batch_size).workers(dataloader_worker_count));
for (int epoch_index = 0; epoch_index < epoch_count; ++epoch_index) {
int iteration_count = 0;
for (auto iterator = data_loader->begin(); iterator != data_loader->end();
++iterator, ++iteration_count) {
if ((iteration_count + 1) % save_interval == 0) {
torch::save(*dataset, tempfile.name);
samplers::SequentialSampler new_sampler(0);
// See Note [save/load ChunkDataset as ChunkSampler]
torch::load(new_sampler, tempfile.name);
// Verify save logic. For ChunkDataset, the chunk data is stored in a
// cache inside the dataset. One pool of threads are constantly
// writing to the cache, and a different pool of thread are constantly
// reading from the cache. Due to the nature of asynchronization, at
// the time of get_batch(), which chunk is written to the cache is not
// fully deterministic.
// But we can still calculate a restricted window on the expected
// output, hence verify the logic. In this test, the cache size is
// configured to be the same as chunk size and batch size. So the
// chunk data is written to the cache one by one. Only the current
// batch is retrieved, the next chunk is writen. Now in iteration 0,
// after the first batch is retrieved, when we save the dataset
// statues, there are three possible scenarios for the writer thread:
// 1. it hasn't started loading the next chunk data yet, so the
// sequential sampler index is still 0;
// 2. it started to load the second chunk, so the sequencial sampler
// index is at 1;
// 3. it finished loading the second chunk, and start to load the
// third chunk, because the cache is still fully occupied by the data
// from the second chunk, it is waiting to write to the cache. At this
// point, the sampler index is at 2.
// So now we have a window of [0, 2], which is what we expected the
// sampler to save the index from. Now noted for sequential sampler,
// it advances to the next index automatically in the call next(). So
// when save the index, it saves the next index in stead of the
// current one. In other word, after getting the first index from
// sequential sampler, it already moves to the second index. So when
// we save it, it is the second index we save. As a result,
// we need to advance the window by one. Now we have the expected
// window of [1, 3].
// This analysis applies to all scenarios. So extend it to a more
// general case: the expected saved index should falling into the
// range of [iteration, iteration + 3], which is the validation
// below.
ASSERT_TRUE(
new_sampler.index() >= iteration_count + 1 &&
new_sampler.index() <= iteration_count + 3);
}
}
}
}
}
// Test ChunkDataset load function.
TEST(DataLoaderTest, ChunkDatasetLoad) {
auto tempfile = c10::make_tempfile();
const size_t prefetch_count = 1;
const size_t batch_size = 10;
const size_t dataloader_worker_count = 0;
const size_t save_interval = 2;
DummyChunkDataReader data_reader;
samplers::SequentialSampler sampler(0);
const size_t skipped_chunk = 2;
// Configure sampler to skip 2 chunks
{
sampler.reset(data_reader.chunk_count());
sampler.next(skipped_chunk);
// See Note [save/load ChunkDataset as ChunkSampler]
torch::save(sampler, tempfile.name);
}
// test functionality across epoch boundary. The first epoch should be
// affected by the checkpoint, but the second should start normally.
const int epoch_count = 2;
datasets::SharedBatchDataset<datasets::ChunkDataset<
DummyChunkDataReader,
samplers::SequentialSampler,
samplers::SequentialSampler>>
dataset = datasets::make_shared_dataset<datasets::ChunkDataset<
DummyChunkDataReader,
samplers::SequentialSampler,
samplers::SequentialSampler>>(
data_reader,
sampler,
sampler,
datasets::ChunkDatasetOptions(
prefetch_count, batch_size, 20 /*cache size*/));
torch::load(*dataset, tempfile.name);
auto data_loader = torch::data::make_data_loader(
dataset, DataLoaderOptions(batch_size).workers(dataloader_worker_count));
for (int epoch_index = 0; epoch_index < epoch_count; ++epoch_index) {
int iteration_count = 0;
// For the first epoch, the returned batch should be returned from the
// third chunk, because the check point skipped the first two chunks. But
// for the next epoch, it should start from the first batch.
int initial_value = epoch_index == 0 ? 15 : 0;
for (auto iterator = data_loader->begin(); iterator != data_loader->end();
++iterator, ++iteration_count) {
DummyChunkDataReader::BatchType batch = *iterator;
std::vector<int> expected_result;
size_t expected_size = (epoch_index > 0 && iteration_count == 3) ? 5 : 10;
expected_result.resize(expected_size);
std::iota(expected_result.begin(), expected_result.end(), initial_value);
ASSERT_EQ(batch.size(), expected_result.size());
ASSERT_TRUE(
std::equal(batch.begin(), batch.end(), expected_result.begin()));
initial_value += batch_size;
}
}
samplers::SequentialSampler new_sampler(0);
// See Note [save/load ChunkDataset as ChunkSampler]
torch::load(new_sampler, tempfile.name);
ASSERT_EQ(new_sampler.index(), skipped_chunk);
}

View File

@ -7,6 +7,8 @@
#include <queue>
#include <thread>
#include <torch/serialize.h>
namespace torch {
namespace data {
namespace datasets {
@ -270,7 +272,7 @@ struct ChunkDatasetOptions {
/// The size of each batch.
TORCH_ARG(size_t, batch_size);
// the capacity of the queue for batch caching.
/// The capacity of the queue for batch caching.
TORCH_ARG(size_t, cache_size) = 2048;
};
@ -308,7 +310,8 @@ class ChunkDataset final
example_sampler_(std::move(example_sampler)),
options_(std::move(options)),
quit_worker_(false),
running_preloaders_(0) {}
running_preloaders_(0),
load_checkpoint_(false) {}
virtual ~ChunkDataset() {
// stop batch buffer first.
@ -332,7 +335,6 @@ class ChunkDataset final
"The requested batch size does not match with the initialized batch size.\n"
" The requested batch size is ", batch_size,
", while the dataset is created with batch size equal to ", options_.batch_size_);
return batch_buffer_->get_batch();
}
@ -352,9 +354,11 @@ class ChunkDataset final
free_workers();
preload_threads_.clear();
chunk_reader_.reset();
chunk_sampler_.reset(chunk_reader_.chunk_count());
if (!load_checkpoint_){
chunk_reader_.reset();
chunk_sampler_.reset(chunk_reader_.chunk_count());
load_checkpoint_ = false;
}
// Throw out any existing cached batch in the buffer and re-creates a new
// chunk buffer.
@ -385,6 +389,17 @@ class ChunkDataset final
return chunk_sampler_;
}
void save(serialize::OutputArchive& archive) const override {
std::lock_guard<std::mutex> lock(chunk_index_guard_);
chunk_sampler_.save(archive);
}
void load(serialize::InputArchive& archive) override{
std::lock_guard<std::mutex> lock(chunk_index_guard_);
chunk_sampler_.load(archive);
load_checkpoint_ = true;
}
private:
/// running on worker thread to preload chunk data.
void preloader(size_t id) {
@ -455,7 +470,10 @@ class ChunkDataset final
std::atomic<size_t> running_preloaders_;
// mutex to synchronize chunk sampler next() call.
std::mutex chunk_index_guard_;
mutable std::mutex chunk_index_guard_;
// boolean value to indicate whether we need to load the checkpoint for chunk_sampler_.
bool load_checkpoint_;
};
} // namespace datasets
} // namespace data

View File

@ -30,7 +30,32 @@ class StatefulDataset
public:
/// Resets internal state of the dataset.
virtual void reset() = 0;
/// Saves the statefulDataset's state to OutputArchive.
virtual void save(serialize::OutputArchive& archive) const = 0;
/// Deserializes the statefulDataset's state from the `archive`.
virtual void load(serialize::InputArchive& archive) = 0;
};
/// Serializes a statefulDataset to `OutputArchive`.
template <typename... Args>
serialize::OutputArchive& operator<<(
serialize::OutputArchive& archive,
const StatefulDataset<Args...>& statefulDataset) {
statefulDataset.save(archive);
return archive;
}
/// Deserializes a statefulDataset from an `InputArchive`.
template <typename... Args>
serialize::InputArchive& operator>>(
serialize::InputArchive& archive,
StatefulDataset<Args...>& statefulDataset) {
statefulDataset.load(archive);
return archive;
}
} // namespace datasets
} // namespace data
} // namespace torch