#include <algorithm>
#include <cstring>
#include <memory>
#include <mutex>
#include <set>
#include <string>
#include <vector>

#include "caffe2/core/operator.h"
#include "caffe2/core/tensor.h"
#include "caffe2/operators/map_ops.h"

namespace caffe2 {
namespace {

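// ReservoirSamplingOp maintains a fixed-size uniform random sample (the
// "reservoir") over a stream of examples, in the style of Vitter's
// Algorithm R: after n examples have been visited, each of them sits in the
// reservoir with probability num_to_collect / n. All state (the reservoir,
// the visit counter, and the optional dedup bookkeeping) lives in blobs that
// are updated in place, and every run takes the MUTEX input, so the operator
// can be applied to successive batches from concurrent threads.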
template <class Context>
class ReservoirSamplingOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  ReservoirSamplingOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<Context>(operator_def, ws),
        numToCollect_(
            OperatorBase::GetSingleArgument<int>("num_to_collect", -1)) {
    CAFFE_ENFORCE(numToCollect_ > 0);
  }

  bool RunOnDevice() override {
    auto& mutex = OperatorBase::Input<std::unique_ptr<std::mutex>>(MUTEX);
    std::lock_guard<std::mutex> guard(*mutex);

    auto* output = Output(RESERVOIR);
    const auto& input = Input(DATA);

    CAFFE_ENFORCE_GE(input.dim(), 1);

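    // Heuristic initialization check: an empty reservoir is uninitialized,
    // and so is one whose first element is a null shared_ptr. The latter
    // case appears aimed at reservoirs of records packed by `PackRecords`,
    // whose elements are std::shared_ptr<std::vector<TensorCPU>>; the cast
    // below reinterprets the first element's raw bytes, so it is only
    // meaningful for such dtypes.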
    bool output_initialized = output->numel() > 0 &&
        (static_cast<std::shared_ptr<std::vector<TensorCPU>>*>(
             output->raw_mutable_data(input.dtype()))[0] != nullptr);

    if (output_initialized) {
      CAFFE_ENFORCE_EQ(output->dim(), input.dim());
      // NOLINTNEXTLINE(clang-diagnostic-sign-compare)
      for (size_t i = 1; i < input.dim(); ++i) {
        CAFFE_ENFORCE_EQ(output->size(i), input.size(i));
      }
    }

    auto num_entries = input.sizes()[0];

    if (!output_initialized) {
      // IMPORTANT: Force the output to have the right type before reserving,
      // so that the output gets the right capacity
      auto dims = input.sizes().vec();
      dims[0] = 0;
      output->Resize(dims);
      output->raw_mutable_data(input.dtype());
      output->ReserveSpace(numToCollect_);
    }

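    // Optional dedup bookkeeping: POS_TO_OBJECT maps a reservoir row to the
    // object id stored there, and OBJECT_TO_POS_MAP is the inverse map from
    // object id to row. Both are reset whenever the reservoir itself was
    // reset above.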
    auto* pos_to_object =
        OutputSize() > POS_TO_OBJECT ? Output(POS_TO_OBJECT) : nullptr;
    if (pos_to_object) {
      if (!output_initialized) {
        // Cleaning up in case the reservoir got reset.
        pos_to_object->Resize(0);
        pos_to_object->template mutable_data<int64_t>();
        pos_to_object->ReserveSpace(numToCollect_);
      }
    }

    auto* object_to_pos_map = OutputSize() > OBJECT_TO_POS_MAP
        ? OperatorBase::Output<MapType64To32>(OBJECT_TO_POS_MAP)
        : nullptr;

    if (object_to_pos_map && !output_initialized) {
      object_to_pos_map->clear();
    }

    auto* num_visited_tensor = Output(NUM_VISITED);
    CAFFE_ENFORCE_EQ(1, num_visited_tensor->numel());
    auto* num_visited = num_visited_tensor->template mutable_data<int64_t>();
    if (!output_initialized) {
      *num_visited = 0;
    }
    CAFFE_ENFORCE_GE(*num_visited, 0);

    if (num_entries == 0) {
      if (!output_initialized) {
        // Get both shape and meta
        output->CopyFrom(input, /* async */ true);
      }
      return true;
    }

    const int64_t* object_id_data = nullptr;
    std::set<int64_t> unique_object_ids;
    if (InputSize() > OBJECT_ID) {
      const auto& object_id = Input(OBJECT_ID);
      CAFFE_ENFORCE_EQ(object_id.dim(), 1);
      CAFFE_ENFORCE_EQ(object_id.numel(), num_entries);
      object_id_data = object_id.template data<int64_t>();
      unique_object_ids.insert(
          object_id_data, object_id_data + object_id.numel());
    }

    const auto num_new_entries = countNewEntries(unique_object_ids);
    auto num_to_copy = std::min<int32_t>(num_new_entries, numToCollect_);
    auto output_batch_size = output_initialized ? output->size(0) : 0;
    auto output_num =
        std::min<size_t>(numToCollect_, output_batch_size + num_to_copy);
    // output_num is >= output_batch_size
    output->ExtendTo(output_num, 50);
    if (pos_to_object) {
      pos_to_object->ExtendTo(output_num, 50);
      // ExtendTo doesn't zero-initialize tensors any more, explicitly clear
      // the memory. Pointer arithmetic on the int64_t* already advances in
      // whole elements, so the offset must not be scaled by sizeof(int64_t)
      // again.
      memset(
          pos_to_object->template mutable_data<int64_t>() + output_batch_size,
          0,
          (output_num - output_batch_size) * sizeof(int64_t));
    }

    auto* output_data =
        static_cast<char*>(output->raw_mutable_data(input.dtype()));
    auto* pos_to_object_data = pos_to_object
        ? pos_to_object->template mutable_data<int64_t>()
        : nullptr;

    auto block_size = input.size_from_dim(1);
    auto block_bytesize = block_size * input.itemsize();
    const auto* input_data = static_cast<const char*>(input.raw_data());

    const auto start_num_visited = *num_visited;

    std::set<int64_t> eligible_object_ids;
    if (object_to_pos_map) {
      for (auto oid : unique_object_ids) {
        if (!object_to_pos_map->count(oid)) {
          eligible_object_ids.insert(oid);
        }
      }
    }

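    // Core of Algorithm R. The first numToCollect_ eligible items are
    // appended directly. Afterwards, the (n+1)-th visited item draws a slot
    // uniformly from [0, n]; it survives only if the slot falls inside the
    // reservoir, i.e. with probability numToCollect_ / (n + 1), which keeps
    // the sample uniform over everything visited so far.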
    for (int i = 0; i < num_entries; ++i) {
      if (object_id_data && object_to_pos_map &&
          !eligible_object_ids.count(object_id_data[i])) {
        // Already in the pool or processed
        continue;
      }
      if (object_id_data) {
        eligible_object_ids.erase(object_id_data[i]);
      }
      int64_t pos = -1;
      if (*num_visited < numToCollect_) {
        // append
        pos = *num_visited;
      } else {
        // uniform between [0, num_visited]
        at::uniform_int_from_to_distribution<int64_t> uniformDist(
            *num_visited + 1, 0);
        pos = uniformDist(context_.RandGenerator());
        if (pos >= numToCollect_) {
          // discard
          pos = -1;
        }
      }

      if (pos < 0) {
        // discard
        CAFFE_ENFORCE_GE(*num_visited, numToCollect_);
      } else {
        // replace
        context_.CopyItemsSameDevice(
            input.dtype(),
            block_size,
            input_data + i * block_bytesize,
            output_data + pos * block_bytesize);

        if (object_id_data && pos_to_object_data && object_to_pos_map) {
          auto old_oid = pos_to_object_data[pos];
          auto new_oid = object_id_data[i];
          pos_to_object_data[pos] = new_oid;
          object_to_pos_map->erase(old_oid);
          object_to_pos_map->emplace(new_oid, pos);
        }
      }

      ++(*num_visited);
    }
    // Sanity check
    CAFFE_ENFORCE_EQ(*num_visited, start_num_visited + num_new_entries);
    return true;
  }

 private:
  // number of tensors to collect
  int numToCollect_;

  INPUT_TAGS(
      RESERVOIR_IN,
      NUM_VISITED_IN,
      DATA,
      MUTEX,
      OBJECT_ID,
      OBJECT_TO_POS_MAP_IN,
      POS_TO_OBJECT_IN);
  OUTPUT_TAGS(RESERVOIR, NUM_VISITED, OBJECT_TO_POS_MAP, POS_TO_OBJECT);

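  // Counts how many entries of the current batch are not already in the
  // reservoir: without an OBJECT_ID input every row counts as new; otherwise
  // only object ids absent from OBJECT_TO_POS_MAP_IN do. The result sizes
  // the reservoir extension above and feeds the final sanity check.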
  int32_t countNewEntries(const std::set<int64_t>& unique_object_ids) {
    const auto& input = Input(DATA);
    if (InputSize() <= OBJECT_ID) {
      return input.size(0);
    }
    const auto& object_to_pos_map =
        OperatorBase::Input<MapType64To32>(OBJECT_TO_POS_MAP_IN);
    return std::count_if(
        unique_object_ids.begin(),
        unique_object_ids.end(),
        [&object_to_pos_map](int64_t oid) {
          return !object_to_pos_map.count(oid);
        });
  }
};

REGISTER_CPU_OPERATOR(ReservoirSampling, ReservoirSamplingOp<CPUContext>);

OPERATOR_SCHEMA(ReservoirSampling)
    .NumInputs({4, 7})
    .NumOutputs({2, 4})
    // Inputs and outputs must match up: the basic form takes 4 inputs and
    // produces 2 outputs; the deduplicating form takes 7 and produces 4.
    .NumInputsOutputs([](int in, int out) { return in / 3 == out / 2; })
    .EnforceInplace({{0, 0}, {1, 1}, {5, 2}, {6, 3}})
    .SetDoc(R"DOC(
Collect `DATA` tensor into `RESERVOIR` of size `num_to_collect`. `DATA` is
assumed to be a batch.

In the case where 'objects' may be repeated in `DATA` and you want at most one
instance of each 'object' in the reservoir, `OBJECT_ID` can be given for
deduplication. If `OBJECT_ID` is given, you also need to supply the additional
book-keeping tensors. See input blob documentation for details.

This operator is thread-safe.
)DOC")
    .Arg(
        "num_to_collect",
        "The size of the reservoir, i.e. the maximum number of examples to "
        "collect")
    .Input(
        0,
        "RESERVOIR",
        "The reservoir; should be initialized to an empty tensor")
    .Input(
        1,
        "NUM_VISITED",
        "Number of examples seen so far; should be initialized to 0")
    .Input(
        2,
        "DATA",
        "Tensor to collect from. The first dimension is assumed to be the "
        "batch size. If the object to be collected is represented by multiple "
        "tensors, use `PackRecords` to pack them into a single tensor.")
    .Input(3, "MUTEX", "Mutex to prevent data races")
    .Input(
        4,
        "OBJECT_ID",
        "(Optional, int64) If provided, used for deduplicating objects in the "
        "reservoir")
    .Input(
        5,
        "OBJECT_TO_POS_MAP_IN",
        "(Optional) Auxiliary bookkeeping map. This should be created from "
        "`CreateMap` with keys of type int64 and values of type int32")
    .Input(
        6,
        "POS_TO_OBJECT_IN",
        "(Optional) Tensor of type int64 used for bookkeeping in deduplication")
    .Output(0, "RESERVOIR", "Same as the input")
    .Output(1, "NUM_VISITED", "Same as the input")
    .Output(2, "OBJECT_TO_POS_MAP", "(Optional) Same as the input")
    .Output(3, "POS_TO_OBJECT", "(Optional) Same as the input");

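// Typical wiring (a sketch only; blob names and the call notation below are
// illustrative, not part of this file): a net first runs `CreateMutex` to
// produce the MUTEX blob (and, for the deduplicating form, `CreateMap` for
// OBJECT_TO_POS_MAP_IN), then runs ReservoirSampling once per batch, with
// RESERVOIR and NUM_VISITED updated in place:
//
//   mutex = CreateMutex()
//   reservoir, num_visited = ReservoirSampling(
//       [reservoir, num_visited, data, mutex], num_to_collect=k)
//
// Here `reservoir` starts as an empty tensor and `num_visited` as a
// single-element int64 tensor holding 0, per the schema above.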
SHOULD_NOT_DO_GRADIENT(ReservoirSampling);
} // namespace
} // namespace caffe2