pytorch/caffe2/operators/counter_ops.h
Jane Xu 71ca600af9 Renaming CAFFE2_API to TORCH_API (#49496)
Summary:
Since caffe2 and torch have been consolidated, CAFFE2_API should be merged with TORCH_API. Addresses a TODO.

Manually edited some references to the removed `CAFFE2_API`:
* `CONTRIBUTING.md`
* `caffe2/proto/CMakeLists.txt`
* `cmake/ProtoBuf.cmake`
* `c10/macros/Export.h`
* `torch/csrc/WindowsTorchApiMacro.h`
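
For context, `TORCH_API` (like the removed `CAFFE2_API`) is a symbol-visibility macro. Below is a rough, illustrative sketch of the usual export/import pattern behind such macros; it is not the exact contents of `c10/macros/Export.h`, and the `CAFFE2_BUILD_MAIN_LIB` guard is an assumption here:

```cpp
// Illustrative only -- the real definitions live in c10/macros/Export.h.
#ifdef _WIN32
#define C10_EXPORT __declspec(dllexport)
#define C10_IMPORT __declspec(dllimport)
#else
#define C10_EXPORT __attribute__((__visibility__("default")))
#define C10_IMPORT C10_EXPORT
#endif

// Export symbols while building the library; import them from client code.
#ifdef CAFFE2_BUILD_MAIN_LIB
#define TORCH_API C10_EXPORT
#else
#define TORCH_API C10_IMPORT
#endif
```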

Pull Request resolved: https://github.com/pytorch/pytorch/pull/49496

Reviewed By: malfet, samestep

Differential Revision: D25600726

Pulled By: janeyx99

fbshipit-source-id: 7e068d959e397ac183c097d7e9a9afeca5ddd782
2020-12-18 10:54:50 -08:00

#ifndef CAFFE2_OPERATORS_COUNTER_OPS_H_
#define CAFFE2_OPERATORS_COUNTER_OPS_H_

#include <atomic>

#include "caffe2/core/context.h"
#include "caffe2/core/logging.h"
#include "caffe2/core/operator.h"

namespace caffe2 {

template <typename T>
class TORCH_API Counter {
 public:
  explicit Counter(T count) : count_(count) {}

  // Returns false while the pre-decrement count is still positive;
  // returns true once the counter has been exhausted.
  bool countDown() {
    if (count_-- > 0) {
      return false;
    }
    return true;
  }

  // Returns the pre-increment count.
  T countUp() {
    return count_++;
  }

  T retrieve() const {
    return count_.load();
  }

  bool checkIfDone() const {
    return (count_.load() <= 0);
  }

  // Sets the count to init_count and returns the previous value.
  T reset(T init_count) {
    return count_.exchange(init_count);
  }

 private:
  std::atomic<T> count_;
};
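
A minimal sketch of the `countDown()`/`checkIfDone()` contract, using only the `Counter<T>` class above; the variable names are illustrative and nothing outside this header is assumed:

```cpp
#include <cassert>

#include "caffe2/operators/counter_ops.h"

int main() {
  caffe2::Counter<int64_t> counter(2); // two "tickets" available

  assert(!counter.countDown());  // 2 -> 1, not exhausted yet
  assert(!counter.countDown());  // 1 -> 0, not exhausted yet
  assert(counter.countDown());   // count was already 0, so the counter is done
  assert(counter.checkIfDone());

  assert(counter.reset(5) <= 0); // reset() returns the previous (exhausted) value
  assert(counter.retrieve() == 5);
  return 0;
}
```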

// TODO(jiayq): deprecate these ops & consolidate them with IterOp/AtomicIterOp
template <typename T, class Context>
class CreateCounterOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  template <class... Args>
  explicit CreateCounterOp(Args&&... args)
      : Operator<Context>(std::forward<Args>(args)...),
        init_count_(this->template GetSingleArgument<T>("init_count", 0)) {
    CAFFE_ENFORCE_LE(0, init_count_, "negative init_count is not permitted.");
  }

  bool RunOnDevice() override {
    *this->template Output<std::unique_ptr<Counter<T>>>(0) =
        std::unique_ptr<Counter<T>>(new Counter<T>(init_count_));
    return true;
  }

 private:
  T init_count_ = 0;
};

template <typename T, class Context>
class ResetCounterOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  template <class... Args>
  explicit ResetCounterOp(Args&&... args)
      : Operator<Context>(std::forward<Args>(args)...),
        init_count_(this->template GetSingleArgument<T>("init_count", 0)) {
    CAFFE_ENFORCE_LE(0, init_count_, "negative init_count is not permitted.");
  }

  bool RunOnDevice() override {
    auto& counterPtr = this->template Input<std::unique_ptr<Counter<T>>>(0);
    auto previous = counterPtr->reset(init_count_);
    if (OutputSize() == 1) {
      auto* output = Output(0);
      output->Resize();
      *output->template mutable_data<T>() = previous;
    }
    return true;
  }

 private:
  T init_count_;
};

// Will always use TensorCPU regardless of the Context.
template <typename T, class Context>
class CountDownOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  template <class... Args>
  explicit CountDownOp(Args&&... args)
      : Operator<Context>(std::forward<Args>(args)...) {}

  bool RunOnDevice() override {
    auto& counterPtr = this->template Input<std::unique_ptr<Counter<T>>>(0);
    auto* output = Output(0);
    output->Resize(std::vector<int>{});
    *output->template mutable_data<bool>() = counterPtr->countDown();
    return true;
  }
};

// Will always use TensorCPU regardless of the Context.
template <typename T, class Context>
class CheckCounterDoneOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  template <class... Args>
  explicit CheckCounterDoneOp(Args&&... args)
      : Operator<Context>(std::forward<Args>(args)...) {}

  bool RunOnDevice() override {
    auto& counterPtr = this->template Input<std::unique_ptr<Counter<T>>>(0);
    auto* output = Output(0);
    output->Resize(std::vector<int>{});
    *output->template mutable_data<bool>() = counterPtr->checkIfDone();
    return true;
  }
};

// Will always use TensorCPU regardless of the Context.
template <typename T, class Context>
class CountUpOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  template <class... Args>
  explicit CountUpOp(Args&&... args)
      : Operator<Context>(std::forward<Args>(args)...) {}

  bool RunOnDevice() override {
    auto& counterPtr = this->template Input<std::unique_ptr<Counter<T>>>(0);
    auto* output = Output(0);
    output->Resize(std::vector<int>{});
    *output->template mutable_data<T>() = counterPtr->countUp();
    return true;
  }
};

// Will always use TensorCPU regardless of the Context.
template <typename T, class Context>
class RetrieveCountOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  template <class... Args>
  explicit RetrieveCountOp(Args&&... args)
      : Operator<Context>(std::forward<Args>(args)...) {}

  bool RunOnDevice() override {
    auto& counterPtr = this->template Input<std::unique_ptr<Counter<T>>>(0);
    auto* output = Output(0);
    output->Resize(std::vector<int>{});
    *output->template mutable_data<T>() = counterPtr->retrieve();
    return true;
  }
};

} // namespace caffe2
#endif // CAFFE2_OPERATORS_COUNTER_OPS_H_
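
For completeness, a rough sketch of how these operators could be exercised through a `caffe2::Workspace`. The operator type names (`CreateCounter`, `CountDown`) and their CPU registrations live in `counter_ops.cc`, not in this header, so the names and the `OperatorDef` plumbing below are assumptions for illustration rather than a verified recipe:

```cpp
#include <string>
#include <vector>

#include "caffe2/core/operator.h"
#include "caffe2/core/workspace.h"
#include "caffe2/proto/caffe2_pb.h"

// Hypothetical helper for this sketch: builds a bare OperatorDef.
static caffe2::OperatorDef MakeOpDef(
    const std::string& type,
    const std::vector<std::string>& inputs,
    const std::vector<std::string>& outputs) {
  caffe2::OperatorDef def;
  def.set_type(type);
  for (const auto& in : inputs) {
    def.add_input(in);
  }
  for (const auto& out : outputs) {
    def.add_output(out);
  }
  return def;
}

void CounterDemo() {
  caffe2::Workspace ws;

  // CreateCounter(init_count=3) stores a unique_ptr<Counter<int64_t>> in "counter".
  auto create_def = MakeOpDef("CreateCounter", {}, {"counter"});
  auto* arg = create_def.add_arg();
  arg->set_name("init_count");
  arg->set_i(3);
  caffe2::CreateOperator(create_def, &ws)->Run();

  // CountDown writes a scalar bool tensor: false while tickets remain,
  // true once the counter is exhausted.
  auto countdown_def = MakeOpDef("CountDown", {"counter"}, {"done"});
  auto countdown_op = caffe2::CreateOperator(countdown_def, &ws);
  countdown_op->Run(); // 3 -> 2, "done" == false
  countdown_op->Run(); // 2 -> 1, "done" == false
}
```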