#include "caffe2/core/context.h"
#include "caffe2/core/operator.h"
#include "caffe2/core/tensor.h"
#include "caffe2/predictor/predictor.h"
#include "caffe2/utils/math.h"

#include <gtest/gtest.h>

#include <cstdint>
#include <memory>
#include <string>
#include <vector>

namespace caffe2 {

namespace {

// Prediction net: a single FC op producing "y" from the external inputs
// "data", "W" and "b".
const char* const predictSpec = R"DOC(
  name: "predict"
  type: "dag"
  external_input: "data"
  external_input: "W"
  external_input: "b"
  external_output: "y"
  op {
    input: "data"
    input: "W"
    input: "b"
    output: "y"
    type: "FC"
  }
)DOC";

// Init net: ConstantFill of "W" (shape 10x4) and "b" (shape 10), both with
// the value 2.0.
const char* const initSpec = R"DOC(
  name: "init"
  type: "dag"
  op {
    type: "ConstantFill"
    output: "W"
    arg {
      name: "shape"
      ints: 10
      ints: 4
    }
    arg {
      name: "value"
      f: 2.0
    }
  }
  op {
    type: "ConstantFill"
    output: "b"
    arg {
      name: "shape"
      ints: 10
    }
    arg {
      name: "value"
      f: 2.0
    }
  }
)DOC";

// MetaNetDef bundling the same predict net with an init net that
// additionally fills a 1x4 "data" blob with 2.0. Not referenced by the tests
// below.
const char* const metaSpec = R"DOC(
  blobs {
    key: "INPUTS_BLOB_TYPE"
    value: "data"
  }
  blobs {
    key: "OUTPUTS_BLOB_TYPE"
    value: "y"
  }
  nets {
    key: "GLOBAL_INIT_NET_TYPE"
    value: {
      name: "init"
      type: "dag"
      op {
        type: "ConstantFill"
        output: "data"
        arg {
          name: "shape"
          ints: 1
          ints: 4
        }
        arg {
          name: "value"
          f: 2.0
        }
      }
      op {
        type: "ConstantFill"
        output: "W"
        arg {
          name: "shape"
          ints: 10
          ints: 4
        }
        arg {
          name: "value"
          f: 2.0
        }
      }
      op {
        type: "ConstantFill"
        output: "b"
        arg {
          name: "shape"
          ints: 10
        }
        arg {
          name: "value"
          f: 2.0
        }
      }
    }
  }
  nets {
    key: "PREDICT_NET_TYPE"
    value: {
      name: "predict"
      type: "dag"
      external_input: "data"
      external_input: "W"
      external_input: "b"
      external_output: "y"
      op {
        input: "data"
        input: "W"
        input: "b"
        output: "y"
        type: "FC"
      }
    }
  }
)DOC";

// Creates a blob holding a CPU float tensor of shape `dims`, filled via
// math::RandUniform with values between -1 and 1 using `ctx` (whose random
// seed therefore determines the contents).
std::unique_ptr<Blob> randomTensor(
    const std::vector<int64_t>& dims,
    CPUContext* ctx) {
  auto blob = std::make_unique<Blob>();
  auto* t = BlobGetMutableTensor(blob.get(), CPU);
  t->Resize(dims);
  math::RandUniform<float, CPUContext>(
      t->numel(), -1.0f, 1.0f, t->mutable_data<float>(), ctx);
  return blob;
}

// Parses `value` as a text-format NetDef; fails through CAFFE_ENFORCE if the
// text cannot be parsed.
NetDef parseNetDef(const std::string& value) {
  NetDef def;
  CAFFE_ENFORCE(
      TextFormat::ParseFromString(value, &def),
      "Failed to parse NetDef with value: ",
      value);
  return def;
}

// Parses `value` as a text-format MetaNetDef; fails through CAFFE_ENFORCE if
// the text cannot be parsed.
MetaNetDef parseMetaNetDef(const std::string& value) {
  MetaNetDef def;
  CAFFE_ENFORCE(
      TextFormat::ParseFromString(value, &def),
      "Failed to parse MetaNetDef with value: ",
      value);
  return def;
}

// Shared checks for a single-row prediction: exactly one output tensor of
// shape 1x10, with element 4 matching the reference value for the seed-1701
// input. ASSERT_* is used for the count and shape so that an empty or
// mis-shaped result stops here instead of reaching front() / data()[4].
//
// W and b are all 2.0 (initSpec), so — assuming FC computes data * W^T + b —
// every output column should be equal and index 4 is just a representative.
void expectSingleBatchOutput(const Predictor::TensorList& output) {
  ASSERT_EQ(output.size(), 1u);
  const auto& y = output.front();
  ASSERT_EQ(y.sizes().size(), 2u);
  ASSERT_EQ(y.size(0), 1);
  ASSERT_EQ(y.size(1), 10);
  EXPECT_NEAR(y.data<float>()[4], 4.9556, 1E-4);
}

} // namespace

// Fixture: a CPU context seeded with 1701 (so randomTensor is deterministic)
// and a Predictor built from initSpec / predictSpec.
class PredictorTest : public testing::Test {
 public:
  void SetUp() override {
    DeviceOption option;
    option.set_random_seed(1701);
    ctx_ = std::make_unique<CPUContext>(option);
    p_ = std::make_unique<Predictor>(
        makePredictorConfig(parseNetDef(initSpec), parseNetDef(predictSpec)));
  }

  std::unique_ptr<CPUContext> ctx_;
  std::unique_ptr<Predictor> p_;
};

// Runs the predictor with a positional (TensorList) 1x4 input.
TEST_F(PredictorTest, SimpleBatchSized) {
  auto inputData = randomTensor({1, 4}, ctx_.get());
  auto* tensor = BlobGetMutableTensor(inputData.get(), CPU);

  Predictor::TensorList input;
  input.emplace_back(tensor->Alias());

  Predictor::TensorList output;
  ASSERT_TRUE((*p_)(input, &output));
  expectSingleBatchOutput(output);
}

// Same as SimpleBatchSized, but binds the input by name ("data") through a
// TensorMap.
TEST_F(PredictorTest, SimpleBatchSizedMapInput) {
  auto inputData = randomTensor({1, 4}, ctx_.get());
  auto* tensor = BlobGetMutableTensor(inputData.get(), CPU);

  Predictor::TensorMap input;
  input.emplace("data", tensor->Alias());

  Predictor::TensorList output;
  ASSERT_TRUE((*p_)(input, &output));
  expectSingleBatchOutput(output);
}

} // namespace caffe2