mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
Summary: Use C++11’s override and remove virtual where applicable. Changes are automatically generated. Reviewed By: Orvid Differential Revision: D14086124 fbshipit-source-id: 2005227d095d776ca3b4309a57f54e25782b9b58
117 lines
2.9 KiB
C++
117 lines
2.9 KiB
C++
#include <atomic>
|
|
#include <condition_variable>
|
|
#include <mutex>
|
|
#include <thread> // NOLINT
|
|
|
|
#include "caffe2/core/db.h"
|
|
#include "caffe2/utils/zmq_helper.h"
|
|
#include "caffe2/core/logging.h"
|
|
|
|
namespace caffe2 {
|
|
namespace db {
|
|
|
|
class ZmqDBCursor : public Cursor {
|
|
public:
|
|
explicit ZmqDBCursor(const string& source)
|
|
: source_(source), socket_(ZMQ_PULL),
|
|
prefetched_(false), finalize_(false) {
|
|
socket_.Connect(source_);
|
|
// Start prefetching thread.
|
|
prefetch_thread_.reset(
|
|
new std::thread([this] { this->Prefetch(); }));
|
|
// obtain the first value.
|
|
Next();
|
|
}
|
|
|
|
~ZmqDBCursor() override {
|
|
finalize_ = true;
|
|
prefetched_ = false;
|
|
producer_.notify_one();
|
|
// Wait for the prefetch thread to finish elegantly.
|
|
prefetch_thread_->join();
|
|
socket_.Disconnect(source_);
|
|
}
|
|
|
|
void Seek(const string& /*key*/) override { /* do nothing */
|
|
}
|
|
|
|
void SeekToFirst() override { /* do nothing */ }
|
|
|
|
void Next() override {
|
|
std::unique_lock<std::mutex> lock(prefetch_access_mutex_);
|
|
while (!prefetched_) consumer_.wait(lock);
|
|
key_ = prefetch_key_;
|
|
value_ = prefetch_value_;
|
|
prefetched_ = false;
|
|
producer_.notify_one();
|
|
}
|
|
|
|
string key() override { return key_; }
|
|
string value() override { return value_; }
|
|
bool Valid() override { return true; }
|
|
|
|
private:
|
|
|
|
void Prefetch() {
|
|
while (!finalize_) {
|
|
std::unique_lock<std::mutex> lock(prefetch_access_mutex_);
|
|
while (prefetched_) producer_.wait(lock);
|
|
if (finalize_) {
|
|
return;
|
|
}
|
|
ZmqMessage msg;
|
|
socket_.RecvTillSuccess(&msg);
|
|
prefetch_key_.assign(static_cast<char*>(msg.data()), msg.size());
|
|
socket_.RecvTillSuccess(&msg);
|
|
prefetch_value_.assign(static_cast<char*>(msg.data()), msg.size());
|
|
prefetched_ = true;
|
|
consumer_.notify_one();
|
|
}
|
|
}
|
|
|
|
string source_;
|
|
ZmqSocket socket_;
|
|
string key_;
|
|
string value_;
|
|
string prefetch_key_;
|
|
string prefetch_value_;
|
|
|
|
unique_ptr<std::thread> prefetch_thread_;
|
|
std::mutex prefetch_access_mutex_;
|
|
std::condition_variable producer_, consumer_;
|
|
std::atomic<bool> prefetched_;
|
|
// finalize_ is used to tell the prefetcher to quit.
|
|
std::atomic<bool> finalize_;
|
|
};
|
|
|
|
class ZmqDB : public DB {
|
|
public:
|
|
ZmqDB(const string& source, Mode mode)
|
|
: DB(source, mode), source_(source) {
|
|
CAFFE_ENFORCE(mode == READ, "ZeroMQ DB only supports read mode.");
|
|
}
|
|
|
|
~ZmqDB() override {}
|
|
|
|
void Close() override {}
|
|
|
|
unique_ptr<Cursor> NewCursor() override {
|
|
return make_unique<ZmqDBCursor>(source_);
|
|
}
|
|
|
|
unique_ptr<Transaction> NewTransaction() override {
|
|
CAFFE_THROW("ZeroMQ DB does not support writing with a transaction.");
|
|
return nullptr; // dummy placeholder to suppress old compiler warnings.
|
|
}
|
|
|
|
private:
|
|
string source_;
|
|
};
|
|
|
|
// Register the ZmqDB class under the name "ZmqDB".
// NOTE(review): REGISTER_CAFFE2_DB is defined outside this file; presumably
// it adds the class to the caffe2 DB registry - confirm.
REGISTER_CAFFE2_DB(ZmqDB, ZmqDB);
|
|
// Convenience alias: register the same class under the lower-case name too.
|
|
REGISTER_CAFFE2_DB(zmqdb, ZmqDB);
|
|
|
|
} // namespace db
|
|
} // namespace caffe2
|