mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/33868 Differential Revision: D20169059 Pulled By: ailzhang fbshipit-source-id: ec12226ae27ddd89fa5bacdd35151981ebfedcfd
110 lines
2.8 KiB
C++
110 lines
2.8 KiB
C++
#include <utility>
|
|
|
|
#include "caffe2/opt/annotations.h"
|
|
|
|
namespace caffe2 {
|
|
|
|
using namespace nom::repr;
|
|
|
|
// Copies `opDef` into the annotation's OpDef member and then raises the
// OpDefExists flag that hasOperatorDef() and the OperatorDef getters consult.
void Caffe2Annotation::setOperatorDef(const caffe2::OperatorDef& opDef) {
  this->OpDef = opDef;
  this->OpDefExists = true;
}
|
|
|
|
bool Caffe2Annotation::hasOperatorDef() const {
|
|
return OpDefExists;
|
|
}
|
|
|
|
// Read-only access to the stored OperatorDef.
// The OpDefExists flag is checked with CAFFE_ENFORCE before the reference to
// the OpDef member is handed out.
const caffe2::OperatorDef& Caffe2Annotation::getOperatorDef() const {
  CAFFE_ENFORCE(OpDefExists,
                "OperatorDef was never set. Use Caffe2Annotation::setOperatorDef.");
  return this->OpDef;
}
|
|
// Mutable access to the stored OperatorDef.
// The OpDefExists flag is checked with CAFFE_ENFORCE; the returned pointer is
// the address of the OpDef member itself, not of a copy.
caffe2::OperatorDef* Caffe2Annotation::getMutableOperatorDef() {
  CAFFE_ENFORCE(OpDefExists,
                "OperatorDef was never set. Use Caffe2Annotation::setOperatorDef.");
  return &this->OpDef;
}
|
|
|
|
// Distributed annotations
|
|
// Copies `devOpt` into the device_option field of the OpDef member.
// Note: this setter does not check OpDefExists before writing.
void Caffe2Annotation::setDeviceOption(const caffe2::DeviceOption& devOpt) {
  caffe2::DeviceOption* target = OpDef.mutable_device_option();
  *target = devOpt;
}
|
|
|
|
bool Caffe2Annotation::hasDeviceOption() const {
|
|
return OpDef.has_device_option();
|
|
}
|
|
|
|
// Read-only access to the device option held by the OpDef member.
// hasDeviceOption() is checked with CAFFE_ENFORCE before the reference is
// returned.
const caffe2::DeviceOption& Caffe2Annotation::getDeviceOption() const {
  CAFFE_ENFORCE(hasDeviceOption(),
                "DeviceOption was never set. Use Caffe2Annotation::setDeviceOption.");
  return this->OpDef.device_option();
}
|
|
// Mutable access to the device option held by the OpDef member.
// hasDeviceOption() is checked with CAFFE_ENFORCE before the pointer obtained
// from OpDef.mutable_device_option() is returned.
caffe2::DeviceOption* Caffe2Annotation::getMutableDeviceOption() {
  CAFFE_ENFORCE(hasDeviceOption(),
                "DeviceOption was never set. Use Caffe2Annotation::setDeviceOption.");
  return this->OpDef.mutable_device_option();
}
|
|
|
|
// Stores the device string on this annotation.
// `device` is taken by value, so it is moved into the Device member rather
// than copied a second time (same sink idiom as setComponentLevels).
void Caffe2Annotation::setDevice(std::string device) {
  Device = std::move(device);
}
|
|
const std::string Caffe2Annotation::getDevice() const {
|
|
return Device;
|
|
}
|
|
|
|
void Caffe2Annotation::setDeviceType(int device) {
|
|
DeviceType = device;
|
|
}
|
|
int Caffe2Annotation::getDeviceType() const {
|
|
return DeviceType;
|
|
}
|
|
|
|
// Records both parallelization settings in one call:
//   s   -> parallelization_scheme_
//   num -> parallelization_
void Caffe2Annotation::setParallelization(
    Caffe2Annotation::ParallelizationScheme s, int num) {
  this->parallelization_scheme_ = s;
  this->parallelization_ = num;
}
|
|
|
|
// Returns the scheme most recently stored in parallelization_scheme_
// (written by setParallelization()).
Caffe2Annotation::ParallelizationScheme
Caffe2Annotation::getParallelizationScheme() const {
  return this->parallelization_scheme_;
}
|
|
|
|
int Caffe2Annotation::getParallelization() const {
|
|
return parallelization_;
|
|
}
|
|
|
|
// Stores `n` as this annotation's key node.
void Caffe2Annotation::setKeyNode(NNGraph::NodeRef n) {
  this->key_node_ = n;
}
|
|
// Returns a reference to the key_node_ member.
// key_node_ is checked with CAFFE_ENFORCE first, so an annotation whose key
// node evaluates to false fails the check instead of returning it.
const NNGraph::NodeRef& Caffe2Annotation::getKeyNode() const {
  CAFFE_ENFORCE(
      key_node_,
      "No key node has been annotated");
  return this->key_node_;
}
|
|
// Stores `n` as this annotation's length node.
void Caffe2Annotation::setLengthNode(NNGraph::NodeRef n) {
  this->length_node_ = n;
}
|
|
// Returns a reference to the length_node_ member.
// length_node_ is checked with CAFFE_ENFORCE first, so an annotation whose
// length node evaluates to false fails the check instead of returning it.
const NNGraph::NodeRef& Caffe2Annotation::getLengthNode() const {
  CAFFE_ENFORCE(
      length_node_,
      "No length node has been annotated");
  return this->length_node_;
}
|
|
|
|
void Caffe2Annotation::setComponentLevels(std::vector<std::string> components) {
|
|
component_levels_ = std::move(components);
|
|
}
|
|
std::vector<std::string> Caffe2Annotation::getComponentLevels() const {
|
|
return component_levels_;
|
|
}
|
|
|
|
bool Caffe2Annotation::classof(const Annotation* A) {
|
|
return A->getKind() == AnnotationKind::Caffe2;
|
|
}
|
|
|
|
} // namespace caffe2
|