pytorch/caffe2/onnx/ssa_test.cc
Nikita Shulga a9b0a921d5 Disable avoid-non-const-global-variables lint check (#62008)
Summary:
The check is disabled because GoogleTest's `TEST` macro, as well as `DEFINE_DISPATCH`, is non-compliant with it.

All changes except those to `.clang-tidy` were generated using the following script:
```
for i in `find . -type f -iname "*.c*" -or -iname "*.h"|xargs grep cppcoreguidelines-avoid-non-const-global-variables|cut -f1 -d:|sort|uniq`;  do sed -i "/\/\/ NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)/d" $i; done
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/62008

Reviewed By: driazati, r-barnes

Differential Revision: D29838584

Pulled By: malfet

fbshipit-source-id: 1b2f8602c945bd4ce50a9bfdd204755556e31d13
2021-07-22 18:04:40 -07:00

81 lines
2.2 KiB
C++

#include "caffe2/core/common.h"
#include "caffe2/core/logging.h"
#include "caffe2/onnx/onnx_exporter.h"
#include <gtest/gtest.h>
#include <string>
#include <tuple>
#include <unordered_map>
#include <unordered_set>
// Conv writes "Y", then Relu reads "Y" and overwrites it in place. After
// SsaRewrite the test expects that:
//  * no operator reads and writes the same blob,
//  * the Relu consumes exactly the blob the Conv produces,
//  * looking up the rewritten external-input name in the returned mapping
//    yields the original name "X", and
//  * the external output keeps its original name "Y".
TEST(SsaTest, ConvReluInplace) {
  caffe2::NetDef net;
  auto* op = net.add_op();
  op->set_type("Conv");
  op->add_input("X");
  op->add_input("W");
  op->add_input("b");
  op->add_output("Y");
  op = net.add_op();
  op->set_type("Relu");
  op->add_input("Y");
  op->add_output("Y");
  net.add_external_input("X");
  net.add_external_output("Y");

  // SsaRewrite rewrites `net` in place and returns a name mapping for the
  // external inputs (checked at the end of the test).
  const std::unordered_map<std::string, std::string> input_mapping =
      caffe2::onnx::SsaRewrite(nullptr, &net);

  // In SSA form no operator may write a blob that it also reads.
  for (const auto& net_op : net.op()) {
    const std::unordered_set<std::string> inputs(
        net_op.input().begin(), net_op.input().end());
    for (const auto& o : net_op.output()) {
      EXPECT_EQ(inputs.count(o), 0u)
          << "op '" << net_op.type() << "' both reads and writes '" << o
          << "'";
    }
  }

  // Fail cleanly rather than index out of range below: generated protobuf
  // element accessors are not guaranteed to bounds-check in release builds.
  ASSERT_GE(net.op_size(), 2);
  ASSERT_GE(net.op(0).output_size(), 1);
  ASSERT_GE(net.op(1).input_size(), 1);
  ASSERT_GE(net.external_input_size(), 1);
  ASSERT_GE(net.external_output_size(), 1);

  // The in-place Relu must now read the Conv's (renamed) output.
  EXPECT_EQ(net.op(0).output(0), net.op(1).input(0));
  EXPECT_EQ("X", input_mapping.at(net.external_input(0)));
  EXPECT_EQ("Y", net.external_output(0));
}
// FC writes "Y", Relu overwrites "Y" in place, and a second FC reads "Y" to
// produce "Z". "Y" is both an intermediate and an external output. After
// SsaRewrite the test expects that:
//  * no operator reads and writes the same blob,
//  * the Relu reads the first FC's output, which is renamed to "Y_0",
//  * the second FC still reads "Y" (the final version keeps the external
//    output name),
//  * looking up the rewritten external-input name in the returned mapping
//    yields "X", and
//  * the external outputs keep their original names "Y" and "Z".
TEST(SsaTest, FC_Relu_FC_InPlace_Output) {
  caffe2::NetDef net;
  auto* op = net.add_op();
  op->set_type("FC");
  op->add_input("X");
  op->add_input("W0");
  op->add_input("b0");
  op->add_output("Y");
  op = net.add_op();
  op->set_type("Relu");
  op->add_input("Y");
  op->add_output("Y");
  op = net.add_op();
  op->set_type("FC");
  op->add_input("Y");
  op->add_input("W2");
  op->add_input("b2");
  op->add_output("Z");
  net.add_external_input("X");
  net.add_external_output("Y");
  net.add_external_output("Z");

  // SsaRewrite rewrites `net` in place and returns a name mapping for the
  // external inputs (checked at the end of the test).
  const std::unordered_map<std::string, std::string> input_mapping =
      caffe2::onnx::SsaRewrite(nullptr, &net);

  // In SSA form no operator may write a blob that it also reads.
  for (const auto& net_op : net.op()) {
    const std::unordered_set<std::string> inputs(
        net_op.input().begin(), net_op.input().end());
    for (const auto& o : net_op.output()) {
      EXPECT_EQ(inputs.count(o), 0u)
          << "op '" << net_op.type() << "' both reads and writes '" << o
          << "'";
    }
  }

  // Fail cleanly rather than index out of range below: generated protobuf
  // element accessors are not guaranteed to bounds-check in release builds.
  ASSERT_GE(net.op_size(), 3);
  ASSERT_GE(net.op(0).output_size(), 1);
  ASSERT_GE(net.op(1).input_size(), 1);
  ASSERT_GE(net.op(2).input_size(), 1);
  ASSERT_GE(net.external_input_size(), 1);
  ASSERT_GE(net.external_output_size(), 2);

  EXPECT_EQ(net.op(0).output(0), net.op(1).input(0));
  EXPECT_EQ("Y", net.op(2).input(0));
  EXPECT_EQ("Y_0", net.op(1).input(0));
  EXPECT_EQ("X", input_mapping.at(net.external_input(0)));
  EXPECT_EQ("Y", net.external_output(0));
  EXPECT_EQ("Z", net.external_output(1));
}