add sorting policy to ChunkDataset (#23053)

Summary:
Add a sorting policy to ChunkDataset.

This is considered an advanced parameter for developers who want to apply a 'sorting policy' to the chunk data before it is sampled into minibatches.

Unlike the collate method, this policy is applied at the chunk level instead of the minibatch level. When a chunk of data is loaded (multiple chunks if cross_chunk_shuffle_count_ is greater than 1), the policy is applied to the full loaded data. It is useful when developers want to perform some pre-processing (such as bucketing) on the chunk data before the example sampler samples it.
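As an illustration, here is a minimal sketch of how the policy is passed (MyChunkReader and the surrounding variables are placeholders; the test added below shows a complete version):

// A preprocessing policy is any std::function<void(UnwrappedBatchType&)>;
// for a ChunkDataReader<int> the unwrapped batch type is std::vector<int>.
auto sorting_policy = [](std::vector<int>& raw_chunk_data) {
  std::sort(raw_chunk_data.begin(), raw_chunk_data.end());
};
auto dataset = datasets::make_shared_dataset<datasets::ChunkDataset<
    MyChunkReader,
    samplers::SequentialSampler,
    samplers::SequentialSampler>>(
    chunk_reader,
    chunk_sampler,
    example_sampler,
    datasets::ChunkDatasetOptions(
        prefetch_count, batch_size, cache_size, cross_chunk_shuffle_count),
    sorting_policy); // new optional argument; empty by default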
Pull Request resolved: https://github.com/pytorch/pytorch/pull/23053

Differential Revision: D16537692

Pulled By: colesbury

fbshipit-source-id: cd21ed40ab787a18b8c6dd304e5b806a7a45e6ba
xzhu1900 2019-07-29 12:12:44 -07:00 committed by Facebook Github Bot
parent a356276d79
commit 31f1928096
2 changed files with 100 additions and 1 deletion

@@ -2224,4 +2224,85 @@ TEST(DataLoaderTest, ChunkDatasetCrossChunkShuffle) {
std::equal(result.begin(), result.end(), expected_result.begin()));
}
}
}
TEST(DataLoaderTest, CustomPreprocessPolicy) {
const size_t chunk_size = 5;
const size_t batch_size = 10;
struct D : public datasets::ChunkDataReader<int> {
public:
using BatchType = datasets::ChunkDataReader<int>::ChunkType;
D(size_t chunk_count) : chunk_count_(chunk_count) {}
BatchType read_chunk(size_t chunk_index) override {
BatchType batch_data(chunk_size);
auto rand_gen = []() { return std::rand() % 100; };
std::generate(batch_data.begin(), batch_data.end(), rand_gen);
return batch_data;
}
size_t chunk_count() override {
return chunk_count_;
}
void reset() override {}
size_t chunk_count_;
};
// Custom preprocessing policy: sort the data in ascending order.
auto sorting_policy = [](std::vector<int>& raw_batch_data) {
std::sort(raw_batch_data.begin(), raw_batch_data.end());
};
std::function<void(std::vector<int>&)> policy_function =
sorting_policy;
const size_t prefetch_count = 1;
const size_t cache_size = 10;
const size_t cross_chunk_shuffle_counts[] = {1, 2};
const size_t chunk_counts[] = {3, 4};
samplers::SequentialSampler chunk_sampler(0);
for (auto chunk_count : chunk_counts) {
for (auto cross_chunk_shuffle_count : cross_chunk_shuffle_counts) {
D data_reader(chunk_count);
datasets::SharedBatchDataset<datasets::ChunkDataset<
D,
samplers::SequentialSampler,
samplers::SequentialSampler>>
dataset = datasets::make_shared_dataset<datasets::ChunkDataset<
D,
samplers::SequentialSampler,
samplers::SequentialSampler>>(
data_reader,
chunk_sampler,
chunk_sampler,
datasets::ChunkDatasetOptions(
prefetch_count,
batch_size,
cache_size,
cross_chunk_shuffle_count),
policy_function);
auto data_loader = torch::data::make_data_loader(
dataset, DataLoaderOptions(batch_size).workers(0));
std::vector<int> result;
for (auto iterator = data_loader->begin(); iterator != data_loader->end();
++iterator) {
auto batch_result = *iterator;
if (batch_result.size() > chunk_size * cross_chunk_shuffle_count) {
for (size_t i = 0; i < batch_result.size(); i += chunk_size) {
ASSERT_TRUE(std::is_sorted(
batch_result.begin() + i,
batch_result.begin() + i + chunk_size));
}
} else {
ASSERT_TRUE(std::is_sorted(batch_result.begin(), batch_result.end()));
}
}
}
}
}

@@ -320,11 +320,14 @@ class ChunkDataset final
ChunkReader chunk_reader,
ChunkSampler chunk_sampler,
ExampleSampler example_sampler,
ChunkDatasetOptions options,
std::function<void(UnwrappedBatchType&)> preprocessing_policy =
std::function<void(UnwrappedBatchType&)>())
: chunk_reader_(std::move(chunk_reader)),
chunk_sampler_(std::move(chunk_sampler)),
example_sampler_(std::move(example_sampler)),
options_(std::move(options)),
preprocessing_policy_(preprocessing_policy),
quit_worker_(false),
running_preloaders_(0),
load_checkpoint_(false) {}
@@ -436,6 +439,9 @@ class ChunkDataset final
std::move(
chunk_data.begin(), chunk_data.end(), std::back_inserter(data));
}
if (preprocessing_policy_) {
preprocessing_policy_(data);
}
if (!data.empty()) { // skip empty chunks.
batch_buffer_->add_chunk_data(std::move(data));
}
@@ -483,6 +489,18 @@ class ChunkDataset final
/// The options the Dataset was configured with.
const ChunkDatasetOptions options_;
// Function wrapper to apply custom processing over chunk data. This is
// considered an advanced parameter for developers who want to apply a
// pre-processing step to the chunk data before it is sampled into
// minibatches. Unlike the collate function, this policy is applied at the
// chunk level instead of the minibatch level. When a chunk of data is
// loaded (multiple chunks if cross_chunk_shuffle_count_ is greater than 1),
// this policy is applied to the full loaded data. It is useful when
// developers want to perform pre-processing (such as bucketing) on the
// chunk data before the example sampler samples the data. By default it is
// empty and no action will be taken.
std::function<void(UnwrappedBatchType&)> preprocessing_policy_;
// Indicates whether the worker thread can be torn down.
std::atomic<bool> quit_worker_;
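To make the bucketing use case mentioned in the comment concrete, here is a hedged sketch of a policy that groups variable-length sequences by length; it assumes a dataset whose UnwrappedBatchType is std::vector<std::vector<int>> and is illustrative rather than part of this PR:

#include <algorithm>
#include <vector>

// Sort sequences by length so that similarly sized examples end up adjacent;
// nearby examples then need less padding when they are batched together.
auto bucketing_policy = [](std::vector<std::vector<int>>& chunk_data) {
  std::sort(
      chunk_data.begin(),
      chunk_data.end(),
      [](const std::vector<int>& a, const std::vector<int>& b) {
        return a.size() < b.size();
      });
};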