mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/9939 Pull Request resolved: https://github.com/facebookresearch/weakly-supervised-action-detection/pull/13 Pull Request resolved: https://github.com/pytorch/translate/pull/166 Pull Request resolved: https://github.com/pytorch/pytorch/pull/9125 Closes https://github.com/pytorch/pytorch/pull/9125 Use inheritance for polymorphism, and remove template parameter This is to change the templating in call sites, the core implementations will change later Before Caffe2 Tensor class was compile-time fixed to bind to a particular device/context. With this change, we're making it a runtime property (stored inside the tensor), but preserve the same semantics. For example, one has to specify device type in order to create a Tensor - there are no uninitialized tensors. More specifically the changes are: 1. We added an extra argument *DeviceType* to most of the constructors of the tensor, e.g. (Tensor(DeviceType type)), 2. Semantics of constructor Tensor(const Tensor<SrcContext>& src, ContextForCopy* context); is changed, in this constructor, the second context is passed in to enable us to call the templated Copy function, it could be in a different context as source and target previously, now we'll enforce that the context should have same device type as src, if it is provided. 3. To preserve 'get-or-construct' semantics of Blob, we added specialized getter Blob::GetMutableTensor that verifies both that Blob contains a Tensor and that it's of a correct type 4. Specifically, Tensor type is not default-constructible any more (as we don't have unknown device tensors) and thus some of the code handling STL containers needs to change Note: Some changes are postponed just to keep this diff a bit smaller. Please see `TODO`s. Reviewed By: ezyang, houseroad Differential Revision: D9024330 fbshipit-source-id: e0b8295d2dc6ebe2963383ded5af799ad17164ba
146 lines
3.2 KiB
C++
146 lines
3.2 KiB
C++
|
|
#pragma once
|
|
|
|
#include <chrono>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "caffe2/core/db.h"
#include "caffe2/core/logging.h"
#include "caffe2/core/stats.h"
#include "caffe2/queue/blobs_queue.h"
|
|
|
|
namespace caffe2 {
|
|
namespace db {
|
|
|
|
namespace {
|
|
const std::string& GetStringFromBlob(Blob* blob) {
|
|
if (blob->template IsType<string>()) {
|
|
return blob->template Get<string>();
|
|
} else if (blob->template IsType<Tensor>()) {
|
|
return *blob->template Get<Tensor>().template data<string>();
|
|
} else {
|
|
CAFFE_THROW("Unsupported Blob type");
|
|
}
|
|
}
|
|
}
|
|
|
|
class BlobsQueueDBCursor : public Cursor {
|
|
public:
|
|
explicit BlobsQueueDBCursor(
|
|
std::shared_ptr<BlobsQueue> queue,
|
|
int key_blob_index,
|
|
int value_blob_index,
|
|
float timeout_secs)
|
|
: queue_(queue),
|
|
key_blob_index_(key_blob_index),
|
|
value_blob_index_(value_blob_index),
|
|
timeout_secs_(timeout_secs),
|
|
inited_(false),
|
|
valid_(false) {
|
|
LOG(INFO) << "BlobsQueueDBCursor constructed";
|
|
CAFFE_ENFORCE(queue_ != nullptr, "queue is null");
|
|
CAFFE_ENFORCE(value_blob_index_ >= 0, "value_blob_index < 0");
|
|
}
|
|
|
|
virtual ~BlobsQueueDBCursor() {}
|
|
|
|
void Seek(const string& /* unused */) override {
|
|
CAFFE_THROW("Seek is not supported.");
|
|
}
|
|
|
|
bool SupportsSeek() override {
|
|
return false;
|
|
}
|
|
|
|
void SeekToFirst() override {
|
|
// not applicable
|
|
}
|
|
|
|
void Next() override {
|
|
unique_ptr<Blob> blob = make_unique<Blob>();
|
|
vector<Blob*> blob_vector{blob.get()};
|
|
auto success = queue_->blockingRead(blob_vector, timeout_secs_);
|
|
if (!success) {
|
|
LOG(ERROR) << "Timed out reading from BlobsQueue or it is closed";
|
|
valid_ = false;
|
|
return;
|
|
}
|
|
|
|
if (key_blob_index_ >= 0) {
|
|
key_ = GetStringFromBlob(blob_vector[key_blob_index_]);
|
|
}
|
|
value_ = GetStringFromBlob(blob_vector[value_blob_index_]);
|
|
valid_ = true;
|
|
}
|
|
|
|
string key() override {
|
|
if (!inited_) {
|
|
Next();
|
|
inited_ = true;
|
|
}
|
|
return key_;
|
|
}
|
|
|
|
string value() override {
|
|
if (!inited_) {
|
|
Next();
|
|
inited_ = true;
|
|
}
|
|
return value_;
|
|
}
|
|
|
|
bool Valid() override {
|
|
return valid_;
|
|
}
|
|
|
|
private:
|
|
std::shared_ptr<BlobsQueue> queue_;
|
|
int key_blob_index_;
|
|
int value_blob_index_;
|
|
float timeout_secs_;
|
|
bool inited_;
|
|
string key_;
|
|
string value_;
|
|
bool valid_;
|
|
};
|
|
|
|
class BlobsQueueDB : public DB {
|
|
public:
|
|
BlobsQueueDB(
|
|
const string& source,
|
|
Mode mode,
|
|
std::shared_ptr<BlobsQueue> queue,
|
|
int key_blob_index = -1,
|
|
int value_blob_index = 0,
|
|
float timeout_secs = 0.0)
|
|
: DB(source, mode),
|
|
queue_(queue),
|
|
key_blob_index_(key_blob_index),
|
|
value_blob_index_(value_blob_index),
|
|
timeout_secs_(timeout_secs) {
|
|
LOG(INFO) << "BlobsQueueDB constructed";
|
|
}
|
|
|
|
virtual ~BlobsQueueDB() {
|
|
Close();
|
|
}
|
|
|
|
void Close() override {}
|
|
unique_ptr<Cursor> NewCursor() override {
|
|
return make_unique<BlobsQueueDBCursor>(
|
|
queue_, key_blob_index_, value_blob_index_, timeout_secs_);
|
|
}
|
|
|
|
unique_ptr<Transaction> NewTransaction() override {
|
|
CAFFE_THROW("Not implemented.");
|
|
}
|
|
|
|
private:
|
|
std::shared_ptr<BlobsQueue> queue_;
|
|
int key_blob_index_;
|
|
int value_blob_index_;
|
|
float timeout_secs_;
|
|
};
|
|
} // namespace db
|
|
} // namespace caffe2
|