pytorch/caffe2/core/test_utils.cc
Duc Ngo 10b16953d1 nomnigraph - easy - use new test utils in converter_nomnigraph_test (#15751)
Summary:
Use the new test utils in converter_nomnigraph_test, and add utils to set the device option name, external inputs, and external outputs.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15751

Differential Revision: D13586228

Pulled By: duc0

fbshipit-source-id: ff809dd7bf9f30641ce2a6fef7e2810f005521c2
2019-01-14 18:38:38 -08:00

117 lines
3.3 KiB
C++

#include "caffe2/core/tensor.h"
#include "caffe2/core/workspace.h"
#include "test_utils.h"
namespace {
// Asserts that two CPU tensors holding elements of type T have the same shape
// and element-wise equal contents.
//
// The comparison is exact (CAFFE_ENFORCE_EQ on each pair of elements); the
// first shape or element mismatch fails the enforce. Callers are expected to
// have verified that both tensors actually store elements of type T.
template <typename T>
void assertTensorEqualsWithType(
    const caffe2::TensorCPU& tensor1,
    const caffe2::TensorCPU& tensor2) {
  CAFFE_ENFORCE_EQ(tensor1.sizes(), tensor2.sizes());
  // Use a 64-bit index: a plain `int` counter could overflow before reaching
  // numel() on very large tensors.
  const int64_t count = tensor1.numel();
  // Fetch the typed data pointers once instead of on every iteration.
  const T* data1 = tensor1.data<T>();
  const T* data2 = tensor2.data<T>();
  for (int64_t idx = 0; idx < count; ++idx) {
    // Report the offending position so a failure is easy to locate.
    CAFFE_ENFORCE_EQ(data1[idx], data2[idx], "Tensors differ at index ", idx);
  }
}
} // namespace
namespace caffe2 {
namespace testing {
// Asserts that two float values are close within epsilon.
//
// The check passes when value1 <= value2 + epsilon and
// value1 >= value2 - epsilon; otherwise one of the enforces below fails.
//
// value1:  the value under test.
// value2:  the reference value.
// epsilon: the tolerance added to / subtracted from value2.
void assertNear(float value1, float value2, float epsilon) {
// The range is checked as two separate one-sided comparisons (upper bound
// first, then lower bound) rather than one combined condition.
// These two enforces will give good debug messages.
CAFFE_ENFORCE_LE(value1, value2 + epsilon);
CAFFE_ENFORCE_GE(value1, value2 - epsilon);
}
// Asserts that two CPU tensors have the same shape, the same element type and
// element-wise equal contents.
//
// Supported element types are float, int and int64_t. The element type is
// taken from tensor1, and tensor2 is required to hold the same type. A tensor
// of any other element type fails the assertion explicitly: falling through
// without comparing anything would let a test pass vacuously.
void assertTensorEquals(const TensorCPU& tensor1, const TensorCPU& tensor2) {
  CAFFE_ENFORCE_EQ(tensor1.sizes(), tensor2.sizes());
  if (tensor1.IsType<float>()) {
    CAFFE_ENFORCE(tensor2.IsType<float>());
    assertTensorEqualsWithType<float>(tensor1, tensor2);
  } else if (tensor1.IsType<int>()) {
    CAFFE_ENFORCE(tensor2.IsType<int>());
    assertTensorEqualsWithType<int>(tensor1, tensor2);
  } else if (tensor1.IsType<int64_t>()) {
    CAFFE_ENFORCE(tensor2.IsType<int64_t>());
    assertTensorEqualsWithType<int64_t>(tensor1, tensor2);
  } else {
    // Add more types above if needed. Until then, refuse to "succeed" on a
    // tensor whose contents were never compared.
    CAFFE_ENFORCE(
        false,
        "assertTensorEquals: unsupported tensor element type; "
        "only float, int and int64_t tensors can be compared.");
  }
}
// Compares, blob by blob, the tensors named in `tensorNames` between two
// workspaces.
//
// For every name, both workspaces must contain a blob with that name, and the
// two tensors fetched through getTensor() must satisfy assertTensorEquals().
// The first violation fails the corresponding enforce.
void assertTensorListEquals(
    const std::vector<std::string>& tensorNames,
    const Workspace& workspace1,
    const Workspace& workspace2) {
  for (const auto& tensorName : tensorNames) {
    // Both sides must know the blob before any tensor is fetched.
    CAFFE_ENFORCE(workspace1.HasBlob(tensorName));
    CAFFE_ENFORCE(workspace2.HasBlob(tensorName));

    // Fetch in a fixed order (workspace1 first) and then compare.
    const auto& lhs = getTensor(workspace1, tensorName);
    const auto& rhs = getTensor(workspace2, tensorName);
    assertTensorEquals(lhs, rhs);
  }
}
// Returns a const reference to the tensor stored in the blob called `name`
// inside `workspace`.
//
// Enforces that the workspace has a blob with that name before reading it.
const caffe2::Tensor& getTensor(
    const caffe2::Workspace& workspace,
    const std::string& name) {
  CAFFE_ENFORCE(workspace.HasBlob(name));
  const auto* blob = workspace.GetBlob(name);
  return blob->Get<caffe2::Tensor>();
}
// Calls workspace->CreateBlob(name) and returns a mutable CPU tensor obtained
// from that blob through BlobGetMutableTensor().
//
// name:      name of the blob to create.
// workspace: workspace that receives the blob; it is dereferenced
//            unconditionally.
caffe2::Tensor* createTensor(
    const std::string& name,
    caffe2::Workspace* workspace) {
  auto* blob = workspace->CreateBlob(name);
  return BlobGetMutableTensor(blob, caffe2::CPU);
}
// Appends a new operator to `net` and fills in its type, inputs and outputs.
//
// type:    operator type, stored via set_type().
// inputs:  input blob names, appended in the given order.
// outputs: output blob names, appended in the given order.
// net:     net definition that receives the operator.
//
// Returns the pointer handed back by net->add_op().
caffe2::OperatorDef* createOperator(
    const std::string& type,
    const std::vector<std::string>& inputs,
    const std::vector<std::string>& outputs,
    caffe2::NetDef* net) {
  caffe2::OperatorDef* opDef = net->add_op();
  opDef->set_type(type);

  for (const std::string& inputName : inputs) {
    opDef->add_input(inputName);
  }

  for (const std::string& outputName : outputs) {
    opDef->add_output(outputName);
  }

  return opDef;
}
// Adds an operator of the given type, inputs and outputs to the wrapped net
// (through createOperator) and remembers it as the most recently created
// operator. Returns *this so calls can be chained.
NetMutator& NetMutator::newOp(
    const std::string& type,
    const std::vector<std::string>& inputs,
    const std::vector<std::string>& outputs) {
  auto* createdOp = createOperator(type, inputs, outputs, net_);
  lastCreatedOp_ = createdOp;
  return *this;
}
// Appends every name in `externalInputs`, in order, to the wrapped net's
// external inputs. Returns *this so calls can be chained.
NetMutator& NetMutator::externalInputs(
    const std::vector<std::string>& externalInputs) {
  for (const std::string& blobName : externalInputs) {
    net_->add_external_input(blobName);
  }
  return *this;
}
// Appends every name in `externalOutputs`, in order, to the wrapped net's
// external outputs. Returns *this so calls can be chained.
NetMutator& NetMutator::externalOutputs(
    const std::vector<std::string>& externalOutputs) {
  for (const std::string& blobName : externalOutputs) {
    net_->add_external_output(blobName);
  }
  return *this;
}
// Sets the device option's node name on the most recently created operator.
//
// Enforces that an operator has already been created (lastCreatedOp_ is not
// null). Returns *this so calls can be chained.
NetMutator& NetMutator::setDeviceOptionName(const std::string& name) {
  CAFFE_ENFORCE(lastCreatedOp_ != nullptr);
  auto* deviceOption = lastCreatedOp_->mutable_device_option();
  deviceOption->set_node_name(name);
  return *this;
}
} // namespace testing
} // namespace caffe2