mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
89 lines
2.3 KiB
C++
89 lines
2.3 KiB
C++
#pragma once
|
|
|
|
#include <memory>
|
|
#include "blobs_queue.h"
|
|
#include "caffe2/core/operator.h"
|
|
|
|
namespace caffe2 {
|
|
|
|
template <typename Context>
|
|
class CreateBlobsQueueOp final : public Operator<Context> {
|
|
public:
|
|
USE_OPERATOR_CONTEXT_FUNCTIONS;
|
|
|
|
CreateBlobsQueueOp(const OperatorDef& operator_def, Workspace* ws)
|
|
: Operator<Context>(operator_def, ws), ws_(ws) {}
|
|
|
|
bool RunOnDevice() override {
|
|
const auto capacity =
|
|
OperatorBase::template GetSingleArgument<int>("capacity", 1);
|
|
const auto numBlobs =
|
|
OperatorBase::template GetSingleArgument<int>("num_blobs", 1);
|
|
const auto enforceUniqueName =
|
|
OperatorBase::template GetSingleArgument<int>(
|
|
"enforce_unique_name", false);
|
|
CHECK_EQ(def().output().size(), 1);
|
|
const auto name = def().output().Get(0);
|
|
auto queuePtr = Operator<Context>::Outputs()[0]
|
|
->template GetMutable<std::shared_ptr<BlobsQueue>>();
|
|
CHECK(queuePtr);
|
|
*queuePtr = std::make_shared<BlobsQueue>(
|
|
ws_, name, capacity, numBlobs, enforceUniqueName);
|
|
return true;
|
|
}
|
|
|
|
private:
|
|
Workspace* ws_{nullptr};
|
|
};
|
|
|
|
template <typename Context>
|
|
class EnqueueBlobsOp final : public Operator<Context> {
|
|
public:
|
|
USE_OPERATOR_CONTEXT_FUNCTIONS;
|
|
using Operator<Context>::Operator;
|
|
bool RunOnDevice() override {
|
|
CHECK_GT(InputSize(), 1);
|
|
auto queue = Operator<Context>::Inputs()[0]
|
|
->template Get<std::shared_ptr<BlobsQueue>>();
|
|
CHECK(queue);
|
|
return queue->blockingWrite(this->Outputs());
|
|
}
|
|
|
|
private:
|
|
};
|
|
|
|
template <typename Context>
|
|
class DequeueBlobsOp final : public Operator<Context> {
|
|
public:
|
|
USE_OPERATOR_CONTEXT_FUNCTIONS;
|
|
using Operator<Context>::Operator;
|
|
bool RunOnDevice() override {
|
|
CHECK_EQ(InputSize(), 1);
|
|
auto queue =
|
|
OperatorBase::Inputs()[0]->template Get<std::shared_ptr<BlobsQueue>>();
|
|
CHECK(queue);
|
|
return queue->blockingRead(this->Outputs());
|
|
}
|
|
|
|
private:
|
|
};
|
|
|
|
template <typename Context>
|
|
class CloseBlobsQueueOp final : public Operator<Context> {
|
|
public:
|
|
USE_OPERATOR_CONTEXT_FUNCTIONS;
|
|
using Operator<Context>::Operator;
|
|
bool RunOnDevice() override {
|
|
CHECK_EQ(InputSize(), 1);
|
|
auto queue =
|
|
OperatorBase::Inputs()[0]->template Get<std::shared_ptr<BlobsQueue>>();
|
|
CHECK(queue);
|
|
queue->close();
|
|
queue.reset();
|
|
return true;
|
|
}
|
|
|
|
private:
|
|
};
|
|
}
|