pytorch/caffe2/queue/blobs_queue_db.h
Jerry Zhang aebf3b47ae Remove template parameter from Tensor (#9939)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/9939

Pull Request resolved: https://github.com/facebookresearch/weakly-supervised-action-detection/pull/13

Pull Request resolved: https://github.com/pytorch/translate/pull/166

Pull Request resolved: https://github.com/pytorch/pytorch/pull/9125

Closes https://github.com/pytorch/pytorch/pull/9125

Use inheritance for polymorphism, and remove the template parameter.
This diff changes the templating at call sites; the core implementations will change later.
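The following sketch contrasts the two designs. It makes some assumptions: `OldTensor`, `BaseContext`, `CPUContext`, `CUDAContext`, `Tensor`, and `IsOnDevice` are hypothetical stand-ins written for illustration, not Caffe2's actual declarations. The sketch shows how a device chosen by a template argument becomes a runtime property reached through a virtual interface.

```cpp
#include <cassert>
#include <memory>

// Hypothetical stand-ins for illustration; these are NOT Caffe2's classes.
enum class DeviceType { CPU, CUDA };

// Before: the device is a compile-time template argument, so
// OldTensor<CPUContext> and OldTensor<CUDAContext> are unrelated types and
// every function that accepts a tensor must itself be a template.
template <class Context>
class OldTensor {};

// After: device contexts share a polymorphic base class...
class BaseContext {
 public:
  virtual ~BaseContext() = default;
  virtual DeviceType device_type() const = 0;
};

class CPUContext final : public BaseContext {
 public:
  DeviceType device_type() const override { return DeviceType::CPU; }
};

class CUDAContext final : public BaseContext {
 public:
  DeviceType device_type() const override { return DeviceType::CUDA; }
};

// ...and a single, non-template tensor type records its device at runtime.
class Tensor {
 public:
  explicit Tensor(DeviceType type) : type_(type) {}
  DeviceType device_type() const { return type_; }

 private:
  DeviceType type_;
};

// A call site that used to be `template <class Context> ...` becomes an
// ordinary function that dispatches through the BaseContext interface.
bool IsOnDevice(const Tensor& tensor, const BaseContext& context) {
  return tensor.device_type() == context.device_type();
}
```

With this shape, one `IsOnDevice(tensor, *context)` works for any `BaseContext` subclass selected at runtime. The templated version would need a separate instantiation per device.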

Previously, the Caffe2 Tensor class was bound to a particular device/context at compile time. With this change, the device becomes a runtime property (stored inside the tensor), while the same semantics are preserved. For example, one has to specify the device type in order to create a Tensor; there are no uninitialized tensors. More specifically, the changes are:

1. We added an extra *DeviceType* argument to most Tensor constructors, e.g. `Tensor(DeviceType type)`.
2. The semantics of the constructor `Tensor(const Tensor<SrcContext>& src, ContextForCopy* context)` have changed. The second argument, a context, is passed in so that the templated Copy function can be called. Previously this context could differ from both the source and the target; now, if a context is provided, we enforce that it has the same device type as `src`.
3. To preserve the 'get-or-construct' semantics of Blob, we added a specialized getter, `Blob::GetMutableTensor`, that verifies both that the Blob contains a Tensor and that the Tensor has the correct device type.
4. In particular, Tensor is no longer default-constructible (there are no unknown-device tensors), so some of the code handling STL containers needs to change.
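Below is a minimal model of points 1 to 4. It assumes simplified, hypothetical `Tensor`, `Context`, and `Blob` classes. Only the names echo Caffe2; the bodies are illustrative, and a Blob here can hold nothing but a Tensor.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <stdexcept>
#include <type_traits>
#include <vector>

// Hypothetical, simplified model of the semantics listed above. Names mirror
// Caffe2 concepts, but this is NOT Caffe2's implementation.
enum class DeviceType { CPU, CUDA };

class Context {
 public:
  explicit Context(DeviceType type) : type_(type) {}
  DeviceType device_type() const { return type_; }

 private:
  DeviceType type_;
};

class Tensor {
 public:
  // (1) The device is a required constructor argument.
  explicit Tensor(DeviceType type) : type_(type) {}

  // (2) Copy with an optional context: if one is given, its device type
  // must match the source tensor's.
  Tensor(const Tensor& src, const Context* context) : type_(src.type_) {
    if (context != nullptr && context->device_type() != src.type_) {
      throw std::invalid_argument("context device type must match src");
    }
    // A real implementation would copy the data here using `context`.
  }

  DeviceType device_type() const { return type_; }

 private:
  DeviceType type_;
};

// (4) No default constructor, so there is no "unknown device" tensor.
static_assert(!std::is_default_constructible<Tensor>::value,
              "Tensor must be created with an explicit DeviceType");

// (3) Get-or-construct: reuse the stored tensor if it has the requested
// device type; otherwise replace it with a new tensor on that device.
class Blob {
 public:
  Tensor* GetMutableTensor(DeviceType type) {
    if (tensor_ == nullptr || tensor_->device_type() != type) {
      tensor_ = std::make_unique<Tensor>(type);
    }
    return tensor_.get();
  }

 private:
  // Simplification: this Blob can only ever hold a Tensor.
  std::unique_ptr<Tensor> tensor_;
};

// (4, continued) vector::resize(n) needs a default constructor, so build
// each element with an explicit device instead.
std::vector<Tensor> MakeTensors(std::size_t n, DeviceType type) {
  std::vector<Tensor> tensors;
  tensors.reserve(n);
  for (std::size_t i = 0; i < n; ++i) {
    tensors.emplace_back(type);
  }
  return tensors;
}
```

Because this `Tensor` has no default constructor, expressions such as `std::vector<Tensor>(n)` or `std::map<K, Tensor>::operator[]` no longer compile. This is the kind of STL-container code that point 4 refers to.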

Note: some changes are postponed to keep this diff smaller; see the `TODO`s.

Reviewed By: ezyang, houseroad

Differential Revision: D9024330

fbshipit-source-id: e0b8295d2dc6ebe2963383ded5af799ad17164ba
2018-07-27 10:56:39 -07:00

#pragma once

#include <chrono>
#include <string>

#include "caffe2/core/db.h"
#include "caffe2/core/logging.h"
#include "caffe2/core/stats.h"
#include "caffe2/queue/blobs_queue.h"

namespace caffe2 {
namespace db {

namespace {

// Returns the string payload of a Blob that holds either a std::string or a
// Tensor of strings (in which case the first element is returned).
const std::string& GetStringFromBlob(Blob* blob) {
  if (blob->template IsType<string>()) {
    return blob->template Get<string>();
  } else if (blob->template IsType<Tensor>()) {
    return *blob->template Get<Tensor>().template data<string>();
  } else {
    CAFFE_THROW("Unsupported Blob type");
  }
}

} // namespace

// Forward-only cursor over a BlobsQueue. Seek() is unsupported, Valid() is
// false until the first read, and the first call to key() or value()
// performs that initial read.
class BlobsQueueDBCursor : public Cursor {
 public:
  explicit BlobsQueueDBCursor(
      std::shared_ptr<BlobsQueue> queue,
      int key_blob_index,
      int value_blob_index,
      float timeout_secs)
      : queue_(queue),
        key_blob_index_(key_blob_index),
        value_blob_index_(value_blob_index),
        timeout_secs_(timeout_secs),
        inited_(false),
        valid_(false) {
    LOG(INFO) << "BlobsQueueDBCursor constructed";
    CAFFE_ENFORCE(queue_ != nullptr, "queue is null");
    CAFFE_ENFORCE(value_blob_index_ >= 0, "value_blob_index < 0");
  }

  virtual ~BlobsQueueDBCursor() {}

  void Seek(const string& /* unused */) override {
    CAFFE_THROW("Seek is not supported.");
  }

  bool SupportsSeek() override {
    return false;
  }

  void SeekToFirst() override {
    // not applicable
  }

  void Next() override {
    unique_ptr<Blob> blob = make_unique<Blob>();
    // Note: blob_vector holds a single Blob, so only index 0 below refers to
    // an existing element.
    vector<Blob*> blob_vector{blob.get()};
    auto success = queue_->blockingRead(blob_vector, timeout_secs_);
    if (!success) {
      LOG(ERROR) << "Timed out reading from BlobsQueue or it is closed";
      valid_ = false;
      return;
    }

    if (key_blob_index_ >= 0) {
      key_ = GetStringFromBlob(blob_vector[key_blob_index_]);
    }
    value_ = GetStringFromBlob(blob_vector[value_blob_index_]);
    valid_ = true;
  }

  string key() override {
    if (!inited_) {
      Next();
      inited_ = true;
    }
    return key_;
  }

  string value() override {
    if (!inited_) {
      Next();
      inited_ = true;
    }
    return value_;
  }

  bool Valid() override {
    return valid_;
  }

 private:
  std::shared_ptr<BlobsQueue> queue_;
  int key_blob_index_;
  int value_blob_index_;
  float timeout_secs_;
  bool inited_;
  string key_;
  string value_;
  bool valid_;
};

// DB backed by a BlobsQueue that supports reading only: NewCursor() returns a
// BlobsQueueDBCursor over the shared queue, and NewTransaction() is not
// implemented.
class BlobsQueueDB : public DB {
 public:
  BlobsQueueDB(
      const string& source,
      Mode mode,
      std::shared_ptr<BlobsQueue> queue,
      int key_blob_index = -1,
      int value_blob_index = 0,
      float timeout_secs = 0.0)
      : DB(source, mode),
        queue_(queue),
        key_blob_index_(key_blob_index),
        value_blob_index_(value_blob_index),
        timeout_secs_(timeout_secs) {
    LOG(INFO) << "BlobsQueueDB constructed";
  }

  virtual ~BlobsQueueDB() {
    Close();
  }

  void Close() override {}

  unique_ptr<Cursor> NewCursor() override {
    return make_unique<BlobsQueueDBCursor>(
        queue_, key_blob_index_, value_blob_index_, timeout_secs_);
  }

  unique_ptr<Transaction> NewTransaction() override {
    CAFFE_THROW("Not implemented.");
  }

 private:
  std::shared_ptr<BlobsQueue> queue_;
  int key_blob_index_;
  int value_blob_index_;
  float timeout_secs_;
};

} // namespace db
} // namespace caffe2
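`BlobsQueueDBCursor` reads lazily. `Valid()` reports false until the first `key()`, `value()`, or `Next()` call, and each `Next()` passes `timeout_secs_` to `BlobsQueue::blockingRead`. The following self-contained sketch reproduces that control flow. It replaces `BlobsQueue` with a hypothetical `SimpleQueue` built on `std::condition_variable`. The API of `SimpleQueue` and `LazyCursor`, and the timeout convention (non-positive means wait forever), are assumptions of the sketch, not BlobsQueue's documented behavior.

```cpp
#include <cassert>
#include <chrono>
#include <condition_variable>
#include <deque>
#include <mutex>
#include <string>
#include <utility>

using Record = std::pair<std::string, std::string>;  // (key, value)

// Hypothetical thread-safe queue standing in for BlobsQueue (NOT its real API).
class SimpleQueue {
 public:
  void Push(Record record) {
    {
      std::lock_guard<std::mutex> lock(mu_);
      records_.push_back(std::move(record));
    }
    cv_.notify_one();
  }

  void Close() {
    {
      std::lock_guard<std::mutex> lock(mu_);
      closed_ = true;
    }
    cv_.notify_all();
  }

  // Sketch convention: a positive timeout bounds the wait, a non-positive one
  // waits indefinitely. Returns false on timeout or once closed and drained.
  bool BlockingRead(Record* out, float timeout_secs) {
    std::unique_lock<std::mutex> lock(mu_);
    auto ready = [this] { return !records_.empty() || closed_; };
    if (timeout_secs > 0) {
      const std::chrono::milliseconds timeout(
          static_cast<long long>(timeout_secs * 1000));
      cv_.wait_for(lock, timeout, ready);
    } else {
      cv_.wait(lock, ready);
    }
    if (records_.empty()) return false;
    *out = std::move(records_.front());
    records_.pop_front();
    return true;
  }

 private:
  std::mutex mu_;
  std::condition_variable cv_;
  std::deque<Record> records_;
  bool closed_ = false;
};

// Hypothetical cursor with BlobsQueueDBCursor's control flow: the first key()
// or value() call performs the initial read; a failed read invalidates it.
class LazyCursor {
 public:
  LazyCursor(SimpleQueue* queue, float timeout_secs)
      : queue_(queue), timeout_secs_(timeout_secs) {
    assert(queue_ != nullptr);
  }

  void Next() {
    Record record;
    valid_ = queue_->BlockingRead(&record, timeout_secs_);
    if (valid_) {
      key_ = std::move(record.first);
      value_ = std::move(record.second);
    }  // On failure the previous key_/value_ are kept, as in the original.
  }

  std::string key() { EnsureInited(); return key_; }
  std::string value() { EnsureInited(); return value_; }
  bool Valid() const { return valid_; }

 private:
  void EnsureInited() {
    if (!inited_) {
      Next();
      inited_ = true;
    }
  }

  SimpleQueue* queue_;
  float timeout_secs_;
  bool inited_ = false;
  bool valid_ = false;
  std::string key_;
  std::string value_;
};
```

With `BlobsQueueDBCursor`, the conventional `for (cursor->SeekToFirst(); cursor->Valid(); cursor->Next())` loop would not execute its body. This is because `SeekToFirst()` is a no-op and `valid_` starts out false. The cursor is instead primed by the first `key()` or `value()` call.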