#pragma once #include #include "caffe2/core/context.h" #include "caffe2/core/operator.h" #include "caffe2/utils/math.h" #include namespace caffe2 { template class KeySplitOp : public Operator { public: USE_OPERATOR_CONTEXT_FUNCTIONS; template explicit KeySplitOp(Args&&... args) : Operator(std::forward(args)...), categorical_limit_( this->template GetSingleArgument("categorical_limit", 0)) { CAFFE_ENFORCE_GT(categorical_limit_, 0); } bool RunOnDevice() override { auto& keys = Input(0); const auto N = keys.numel(); const T *const keys_data = keys.template data(); std::vector counts(categorical_limit_); std::vector eids(categorical_limit_); for (const auto k : c10::irange(categorical_limit_)) { counts[k] = 0; } for (const auto i : c10::irange(N)) { const auto k = keys_data[i]; CAFFE_ENFORCE_GT(categorical_limit_, k); CAFFE_ENFORCE_GE(k, 0); counts[k]++; } for (const auto k : c10::irange(categorical_limit_)) { auto *const eid = Output(k, {counts[k]}, at::dtype()); eids[k] = eid->template mutable_data(); counts[k] = 0; } for (const auto i : c10::irange(N)) { const auto k = keys_data[i]; eids[k][counts[k]++] = i; } return true; } private: int categorical_limit_; }; } // namespace caffe2