add sorting policy to ChunkDataset (#23053)
Summary: Add a sorting policy to ChunkDataset. This is considered an advanced parameter for developers who want to apply a 'sorting policy' to the chunk data before it is sampled into minibatches. Unlike the collate method, this policy is applied at the chunk level rather than the minibatch level. When a chunk of data is loaded (multiple chunks if cross_chunk_shuffle_count_ is greater than 1), the policy is applied to the full loaded data. It is useful when developers want to perform some pre-processing (like bucketing) on the chunk data before the example sampler samples it.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/23053
Differential Revision: D16537692
Pulled By: colesbury
fbshipit-source-id: cd21ed40ab787a18b8c6dd304e5b806a7a45e6ba
This commit is contained in:
parent a356276d79
commit 31f1928096
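As a rough sketch of the 'bucketing' use case mentioned in the summary (the element type, bucket width, and function name are illustrative only, not part of this commit), a chunk-level policy could group examples into coarse value buckets before the example sampler draws minibatches:

    #include <algorithm>
    #include <vector>

    // Hypothetical chunk-level policy: group int examples into coarse value
    // buckets so that each minibatch tends to draw similar examples.
    void bucketing_policy(std::vector<int>& chunk_data) {
      const int bucket_width = 10; // illustrative bucket size
      // stable_sort keeps the loaded order within each bucket.
      std::stable_sort(
          chunk_data.begin(), chunk_data.end(), [bucket_width](int a, int b) {
            return a / bucket_width < b / bucket_width;
          });
    }

Such a policy would be passed as the last constructor argument of ChunkDataset, exactly like policy_function in the test below.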
@@ -2224,4 +2224,85 @@ TEST(DataLoaderTest, ChunkDatasetCrossChunkShuffle) {
           std::equal(result.begin(), result.end(), expected_result.begin()));
     }
   }
 }
+
+TEST(DataLoaderTest, CustomPreprocessPolicy) {
+  const size_t chunk_size = 5;
+  const size_t batch_size = 10;
+
+  struct D : public datasets::ChunkDataReader<int> {
+   public:
+    using BatchType = datasets::ChunkDataReader<int>::ChunkType;
+    D(size_t chunk_count) : chunk_count_(chunk_count) {}
+
+    BatchType read_chunk(size_t chunk_index) override {
+      BatchType batch_data(chunk_size);
+      auto rand_gen = []() { return std::rand() % 100; };
+      std::generate(batch_data.begin(), batch_data.end(), rand_gen);
+      return batch_data;
+    }
+
+    size_t chunk_count() override {
+      return chunk_count_;
+    }
+
+    void reset() override {}
+    size_t chunk_count_;
+  };
+
+  // custom preprocessing policy - sort the data in ascending order
+  auto sorting_policy = [](std::vector<int>& raw_batch_data) {
+    std::sort(raw_batch_data.begin(), raw_batch_data.end());
+  };
+  std::function<void(std::vector<int>&)> policy_function =
+      sorting_policy;
+
+  const size_t prefetch_count = 1;
+  const size_t cache_size = 10;
+  const size_t cross_chunk_shuffle_counts[] = {1, 2};
+  const size_t chunk_counts[] = {3, 4};
+
+  samplers::SequentialSampler chunk_sampler(0);
+
+  for (auto chunk_count : chunk_counts) {
+    for (auto cross_chunk_shuffle_count : cross_chunk_shuffle_counts) {
+      D data_reader(chunk_count);
+
+      datasets::SharedBatchDataset<datasets::ChunkDataset<
+          D,
+          samplers::SequentialSampler,
+          samplers::SequentialSampler>>
+          dataset = datasets::make_shared_dataset<datasets::ChunkDataset<
+              D,
+              samplers::SequentialSampler,
+              samplers::SequentialSampler>>(
+              data_reader,
+              chunk_sampler,
+              chunk_sampler,
+              datasets::ChunkDatasetOptions(
+                  prefetch_count,
+                  batch_size,
+                  cache_size,
+                  cross_chunk_shuffle_count),
+              policy_function);
+
+      auto data_loader = torch::data::make_data_loader(
+          dataset, DataLoaderOptions(batch_size).workers(0));
+
+      std::vector<int> result;
+      for (auto iterator = data_loader->begin(); iterator != data_loader->end();
+           ++iterator) {
+        auto batch_result = *iterator;
+        if (batch_result.size() > chunk_size * cross_chunk_shuffle_count) {
+          for (size_t i = 0; i < batch_result.size(); i += chunk_size) {
+            ASSERT_TRUE(std::is_sorted(
+                batch_result.begin() + i,
+                batch_result.begin() + i + chunk_size));
+          }
+        } else {
+          ASSERT_TRUE(std::is_sorted(batch_result.begin(), batch_result.end()));
+        }
+      }
+    }
+  }
+}
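Assuming a typical PyTorch C++ build layout (the binary name and path below may differ per setup), the new test can be run in isolation with a standard gtest filter:

    ./build/bin/test_api --gtest_filter=DataLoaderTest.CustomPreprocessPolicy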
@@ -320,11 +320,14 @@ class ChunkDataset final
       ChunkReader chunk_reader,
       ChunkSampler chunk_sampler,
       ExampleSampler example_sampler,
-      ChunkDatasetOptions options)
+      ChunkDatasetOptions options,
+      std::function<void(UnwrappedBatchType&)> preprocessing_policy =
+          std::function<void(UnwrappedBatchType&)>())
       : chunk_reader_(std::move(chunk_reader)),
         chunk_sampler_(std::move(chunk_sampler)),
         example_sampler_(std::move(example_sampler)),
         options_(std::move(options)),
+        preprocessing_policy_(preprocessing_policy),
         quit_worker_(false),
         running_preloaders_(0),
         load_checkpoint_(false) {}
@@ -436,6 +439,9 @@ class ChunkDataset final
         std::move(
             chunk_data.begin(), chunk_data.end(), std::back_inserter(data));
       }
+      if (preprocessing_policy_) {
+        preprocessing_policy_(data);
+      }
       if (!data.empty()) { // skip empty chunks.
         batch_buffer_->add_chunk_data(std::move(data));
       }
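The guard above relies on standard std::function semantics: a default-constructed std::function (which is exactly what the new constructor parameter defaults to) converts to false, so call sites that omit the policy skip this branch and behave as before. A minimal standalone illustration, independent of this commit:

    #include <algorithm>
    #include <cassert>
    #include <functional>
    #include <vector>

    int main() {
      // Mirrors the default value of the new constructor parameter.
      std::function<void(std::vector<int>&)> policy;
      assert(!policy); // empty: the if (preprocessing_policy_) guard is false

      policy = [](std::vector<int>& data) {
        std::sort(data.begin(), data.end());
      };
      assert(policy); // non-empty: the policy would now be applied
      return 0;
    }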
@@ -483,6 +489,18 @@ class ChunkDataset final
   /// The options the Dataset was configured with.
   const ChunkDatasetOptions options_;
 
+  // A function pointer wrapper to apply custom processing over chunk data.
+  // This is considered an advanced parameter for developers who want to
+  // pre-process the chunk data before sampling it into minibatches. Unlike
+  // the collate function, this policy is applied at the chunk level rather
+  // than the minibatch level. When a chunk of data is loaded (multiple chunks
+  // if cross_chunk_shuffle_count_ is greater than 1), this policy is applied
+  // to the full loaded data. It is useful when developers want to perform
+  // pre-processing (like bucketing) on the chunk data before the example
+  // sampler samples it. By default it is an empty function and no action is
+  // taken.
+  std::function<void(UnwrappedBatchType&)> preprocessing_policy_;
+
   // indicate whether the worker thread can be torn down
   std::atomic<bool> quit_worker_;