Summary:
## Motivation

This PR upgrades MKL-DNN from v0.20 to DNNL v1.2 and resolves https://github.com/pytorch/pytorch/issues/30300. DNNL (Deep Neural Network Library) is the new brand of MKL-DNN and improves performance, quality, and usability over the old version. This PR focuses on migrating all existing functionality, along with minor fixes, performance improvements, and code cleanup. It serves as the cornerstone of our future efforts to accommodate new features such as OpenCL support, BF16 training, and INT8 inference, and to let the PyTorch community derive more benefit from the Intel Architecture.

## What's included?

Even though DNNL makes many breaking changes to its API, we managed to absorb most of them in ideep, so this PR contains minimal changes to the integration code in PyTorch. Below is a summary of the changes.

**General:**

1. Replace the op-level allocator with a globally registered allocator

```
// before
ideep::sum::compute<AllocForMKLDNN>(scales, {x, y}, z);

// after
ideep::sum::compute(scales, {x, y}, z);
```

The allocator is now registered in `aten/src/ATen/native/mkldnn/IDeepRegistration.cpp`. Thereafter, all tensors derived from the `cpu_engine` (by default) will use the c10 allocator.

```
RegisterEngineAllocator cpu_alloc(
  ideep::engine::cpu_engine(),
  [](size_t size) {
    return c10::GetAllocator(c10::DeviceType::CPU)->raw_allocate(size);
  },
  [](void* p) {
    c10::GetAllocator(c10::DeviceType::CPU)->raw_deallocate(p);
  }
);
```

------

2. Simplify group convolution

Convolution had a scenario where the ideep tensor shape did not match the aten tensor shape: when `groups > 1`, DNNL expects weight tensors to be 5-d with an extra group dimension, e.g. `goihw` instead of `oihw` in the 2d conv case. As shown below, this shape difference used to require a lot of extra checks. We have now hidden the difference entirely inside ideep, so all tensors align with PyTorch's definition and these checks can safely be removed from both the aten and caffe2 integration code.

```
// aten/src/ATen/native/mkldnn/Conv.cpp
if (w.ndims() == x.ndims() + 1) {
  AT_ASSERTM(
      groups > 1,
      "Only group _mkldnn_conv2d weights could have been reordered to 5d");
  kernel_size[0] = w.get_dim(0) * w.get_dim(1);
  std::copy_n(
      w.get_dims().cbegin() + 2, x.ndims() - 1, kernel_size.begin() + 1);
} else {
  std::copy_n(w.get_dims().cbegin(), x.ndims(), kernel_size.begin());
}
```

------

3. Enable the DNNL built-in cache

Previously, we stored DNNL JITted kernels along with intermediate buffers inside ideep using an LRU cache. We are now switching to the newly added DNNL built-in cache and **no longer** cache buffers, in order to reduce the memory footprint. This change mainly shows up as lower memory usage in profiling results. On the code side, we removed the few lines of `op_key_` handling that depended on the old ideep cache.

------

4. Use 64-bit integers to denote dimensions

We changed the type of `ideep::dims` from `vector<int32_t>` to `vector<int64_t>`. This renders ideep dims no longer compatible with the 32-bit dims used by caffe2, so we use something like `{stride_.begin(), stride_.end()}` to cast a parameter such as `stride_` into an int64 vector.
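To make the iterator-range cast in item 4 concrete, here is a minimal, standalone sketch. It is not the actual caffe2 call site; a plain `std::vector<int64_t>` stands in for `ideep::dims`, and `stride_` is just a local example variable.

```
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // caffe2 historically carries such parameters as 32-bit ints.
  std::vector<int> stride_{2, 2};

  // ideep::dims is now a vector<int64_t>; constructing it from an iterator
  // range widens each element to int64_t on the fly.
  std::vector<int64_t> stride64{stride_.begin(), stride_.end()};

  for (auto s : stride64) {
    std::cout << s << " ";  // prints: 2 2
  }
  std::cout << "\n";
  return 0;
}
```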
**Misc changes in each commit:**

**Commit:** change build options

Some build options were slightly changed, mainly to avoid name collisions with other projects that include DNNL as a subproject. In addition, the DNNL built-in cache is enabled by the option `DNNL_ENABLE_PRIMITIVE_CACHE`.

Old | New
-- | --
WITH_EXAMPLE | MKLDNN_BUILD_EXAMPLES
WITH_TEST | MKLDNN_BUILD_TESTS
MKLDNN_THREADING | MKLDNN_CPU_RUNTIME
MKLDNN_USE_MKL | N/A (MKL is no longer used)

------

**Commit:** aten reintegration

- aten/src/ATen/native/mkldnn/BinaryOps.cpp
  Implement binary ops using the new `binary` operation provided by DNNL
- aten/src/ATen/native/mkldnn/Conv.cpp
  Clean up group convolution checks
  Simplify conv backward integration
- aten/src/ATen/native/mkldnn/MKLDNNConversions.cpp
  Simplify prepacking of convolution weights
- test/test_mkldnn.py
  Fix an issue in the conv2d unit test: it did not actually compare conv results between the mkldnn and aten implementations; it compared mkldnn with mkldnn, because the default CPU path also goes through mkldnn. We now use `torch.backends.mkldnn.flags` to fix this
- torch/utils/mkldnn.py
  Prepack the weight tensor in the module's `__init__` to achieve significantly better performance

------

**Commit:** caffe2 reintegration

- caffe2/ideep/ideep_utils.h
  Clean up unused type definitions
- caffe2/ideep/operators/adam_op.cc & caffe2/ideep/operators/momentum_sgd_op.cc
  Unify tensor initialization with `ideep::tensor::init`; obsolete `ideep::tensor::reinit`
- caffe2/ideep/operators/conv_op.cc & caffe2/ideep/operators/quantization/int8_conv_op.cc
  Clean up group convolution checks
  Revamp the convolution API
- caffe2/ideep/operators/conv_transpose_op.cc
  Clean up group convolution checks
  Clean up deconv workaround code

------

**Commit:** custom allocator

- Register the c10 allocator as mentioned above

## Performance

We tested inference on some common models based on user scenarios, and most performance numbers are either better than or on par with DNNL 0.20.

ratio: new / old | Latency (batch=1, 4T) | Throughput (batch=64, 56T)
-- | -- | --
pytorch resnet18 | 121.4% | 99.7%
pytorch resnet50 | 123.1% | 106.9%
pytorch resnext101_32x8d | 116.3% | 100.1%
pytorch resnext50_32x4d | 141.9% | 104.4%
pytorch mobilenet_v2 | 163.0% | 105.8%
caffe2 alexnet | 303.0% | 99.2%
caffe2 googlenet-v3 | 101.1% | 99.2%
caffe2 inception-v1 | 102.2% | 101.7%
caffe2 mobilenet-v1 | 356.1% | 253.7%
caffe2 resnet101 | 100.4% | 99.8%
caffe2 resnet152 | 99.8% | 99.8%
caffe2 shufflenet | 141.1% | 69.0% †
caffe2 squeezenet | 98.5% | 99.2%
caffe2 vgg16 | 136.8% | 100.6%
caffe2 googlenet-v3 int8 | 100.0% | 100.7%
caffe2 mobilenet-v1 int8 | 779.2% | 943.0%
caffe2 resnet50 int8 | 99.5% | 95.5%

_Configuration:
Platform: Skylake 8180
Latency test: 4 threads, warmup 30, iteration 500, batch size 1
Throughput test: 56 threads, warmup 30, iteration 200, batch size 64_

† Shufflenet is one of the few models that require temporary buffers during inference. The performance degradation is expected, since we no longer cache any buffers in ideep. As a remedy, we suggest users opt for a caching allocator such as **jemalloc** as a drop-in replacement for the system allocator in such heavy workloads.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/32422

Test Plan:
Perf results: https://our.intern.facebook.com/intern/fblearner/details/177790608?tab=Experiment%20Results
10% improvement for ResNext with avx512, neutral on avx2
More results: https://fb.quip.com/ob10AL0bCDXW#NNNACAUoHJP

Reviewed By: yinghai

Differential Revision: D20381325

Pulled By: dzhulgakov

fbshipit-source-id: 803b906fd89ed8b723c5fcab55039efe3e4bcb77
#include "caffe2/opt/optimize_ideep.h"
|
|
#include "caffe2/opt/converter.h"
|
|
|
|
#ifdef CAFFE2_USE_MKLDNN
|
|
#include <cpuinfo.h>
|
|
#include "caffe2/ideep/ideep_utils.h"
|
|
#endif
|
|
|
|
namespace caffe2 {
|
|
namespace opt {
|
|
|
|
using namespace nom;
|
|
|
|
#ifndef CAFFE2_USE_MKLDNN
|
|
void OptimizeForMkldnn(
|
|
repr::NNModule* nn,
|
|
caffe2::Workspace* ws,
|
|
bool training_mode) {
|
|
LOG(WARNING) << "Only support optimizations for IDEEP";
|
|
}
|
|
|
|
#else
|
|
USE_IDEEP_DEF_ALIASES();
|
|
|
|
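// Helpers below look up workspace blobs / ideep tensors referenced by graph
// nodes and access the caffe2 OperatorDef behind an NNGraph operator.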
Blob* getBlob(const std::string name, caffe2::Workspace* ws) {
  CAFFE_ENFORCE(ws->HasBlob(name), "Blob ", name, " not in workspace");
  return ws->GetBlob(name);
}

Blob* getBlob(repr::NNGraph::NodeRef node, caffe2::Workspace* ws) {
  auto tensor = repr::nn::get<repr::Tensor>(node);
  return getBlob(tensor->getName(), ws);
}

template <class T>
T getTensor(Blob* blob) {
  CAFFE_ENFORCE(blob, "Blob is invalid");
  return blob->template Get<T>();
}

template <class T>
T* getMutableTensor(Blob* blob) {
  CAFFE_ENFORCE(blob, "Blob is invalid");
  if (blob && blob->template IsType<T>()) {
    return blob->template GetMutable<T>();
  }
  return nullptr;
}

const caffe2::OperatorDef& getOpDef(const repr::NeuralNetOperator& nnOp) {
  auto annotation = nnOp.getAnnotation();
  if (annotation == nullptr) {
    CAFFE_THROW("Cannot get Operator annotation");
  }
  return dyn_cast<Caffe2Annotation>(annotation)->getOperatorDef();
}

caffe2::OperatorDef* getMutableOpDef(repr::NeuralNetOperator& nnOp) {
  auto annotation = nnOp.getMutableAnnotation();
  if (annotation == nullptr) {
    CAFFE_THROW("Cannot get Operator annotation");
  }
  return dyn_cast<Caffe2Annotation>(annotation)->getMutableOperatorDef();
}

bool isOpType(const repr::NNGraph::NodeRef& nodeRef, string typeName) {
  if (!repr::nn::is<repr::NeuralNetOperator>(nodeRef)) {
    return false;
  }
  auto op = repr::nn::get<repr::NeuralNetOperator>(nodeRef);
  auto opDef = getOpDef(*op);
  return opDef.type() == typeName;
}

bool isOnIdeepDevice(const repr::NeuralNetOperator& nnOp) {
  // We only want to fuse for IDEEP operators
  const auto& op = getOpDef(nnOp);
  return op.device_option().device_type() == DeviceTypeProto::PROTO_IDEEP;
}

bool isConvFusion(repr::NNGraph::NodeRef convNode, int fusion_type) {
  // Here we only check the type of ConvFusion op (for FP32 only)
  if (!repr::nn::is<repr::Conv>(convNode)) {
    return false;
  }

  auto conv = repr::nn::get<repr::Conv>(convNode);
  auto& op = getOpDef(*conv);

  if (op.type() == "ConvFusion") {
    for (const auto& arg : op.arg()) {
      if (arg.name() == "fusion_type") {
        if (fusion_type == FUSION_MAX) {
          return true;
        }
        return arg.i() == fusion_type;
      }
    }
  }

  return false;
}

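// Rewrite a Conv op into a ConvFusion op carrying the given fusion_type.
// If the op is already a ConvFusion, only the FUSION_CONV_SUM ->
// FUSION_CONV_SUM_RELU upgrade is allowed.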
void resetConvForFusion(repr::NNGraph::NodeRef convNode, int fusion_type) {
  auto conv = repr::nn::get<repr::Conv>(convNode);
  auto* op = getMutableOpDef(*conv);

  if (op == nullptr) {
    return;
  }

  if (op->type() == "ConvFusion") {
    CAFFE_ENFORCE(fusion_type == FUSION_CONV_RELU, "Invalid nest fusion");
    for (auto& arg : *op->mutable_arg()) {
      if (arg.name() == "fusion_type") {
        CAFFE_ENFORCE(arg.i() == FUSION_CONV_SUM, "Invalid nest fusion");
        // Only from FUSION_CONV_SUM to FUSION_CONV_SUM_RELU
        arg.set_i(FUSION_CONV_SUM_RELU);
        return;
      }
    }
    CAFFE_THROW("Can not find fusion type in ConvFusion");
  }

  CAFFE_ENFORCE_LT(fusion_type, FUSION_CONV_SUM_RELU, "Invalid fusion type");
  op->set_type("ConvFusion");
  auto* arg = op->add_arg();
  arg->set_name("fusion_type");
  arg->set_i(fusion_type);
}

void removeArg(repr::NeuralNetOperator& nnOp, std::string argName) {
  auto* op = getMutableOpDef(nnOp);
  auto& opArgs = *op->mutable_arg();
  auto remove_arg = [](decltype(opArgs)& args, std::string& name) {
    for (auto it = args.begin(); it != args.end(); it++) {
      if (it->name() == name) {
        args.erase(it);
        return true;
      }
    }
    return false;
  };
  while (remove_arg(opArgs, argName))
    ;
}

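// Copy the argument named argName from srcOp to dstOp, removing any
// pre-existing copy on dstOp first. The source op is left untouched.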
void moveOpArg(
    caffe2::Workspace* ws,
    std::string argName,
    repr::NeuralNetOperator* srcOp,
    repr::NeuralNetOperator* dstOp) {
  if (argName.empty() || srcOp == nullptr || dstOp == nullptr || srcOp == dstOp)
    return;
  removeArg(*dstOp, argName);

  auto& src = getOpDef(*srcOp);
  auto& src_args = src.arg();
  auto src_it = src_args.begin();
  for (; src_it != src_args.end(); src_it++) {
    if (src_it->name() == argName)
      break;
  }
  if (src_it == src_args.end())
    return;

  auto* dst = getMutableOpDef(*dstOp);
  auto* arg = dst->add_arg();
  *arg = *src_it;
  arg->set_name(argName);
}

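// Remove in-place StopGradient ops (input name == output name), which are
// no-ops at inference time. Returns true if a node was removed.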
bool removeStopGradientForInference(repr::NNModule* nn, caffe2::Workspace* ws) {
  auto allNodes = nn->dataFlow.getMutableNodes();
  for (int i = 0; i < allNodes.size(); ++i) {
    auto node = allNodes[i];
    if (!isOpType(node, "StopGradient")) {
      continue;
    }

    auto stopGradInput = repr::nn::getInputs(node).front();
    auto stopGradOutput = repr::nn::getOutputs(node).front();
    auto inputName = repr::nn::get<repr::Tensor>(stopGradInput)->getName();
    auto outputName = repr::nn::get<repr::Tensor>(stopGradOutput)->getName();
    if (inputName == outputName) {
      nn->dataFlow.replaceNode(stopGradOutput, stopGradInput);
      nn->dataFlow.deleteNode(node);
      return true;
    }
  }
  return false;
}

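// Fold the scale/bias (and, for BN, mean/variance) of a following
// BatchNormalization or AffineChannel op into the Conv filter and bias.
// Returns true after each successful fold so the caller can iterate to a
// fixed point.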
bool fuseConvBNAndAffCh(repr::NNModule* nn, caffe2::Workspace* ws) {
  for (auto node_pair : repr::nn::dataIterator<repr::Conv>(nn->dataFlow)) {
    bool no_bias = false;
    repr::NNGraph::NodeRef convNode;
    repr::Conv* conv;
    std::tie(conv, convNode) = node_pair;

    if (!isOnIdeepDevice(*conv)) {
      LOG(WARNING) << "Not a IDEEP operator";
      continue;
    }

    const auto& convOp = getOpDef(*conv);
    if (convOp.type() == "ConvFusion") {
      continue;
    }

    auto convOutput = repr::nn::getOutputs(convNode).front();
    auto consumers = repr::nn::getConsumers(convOutput);
    // convOutput is NOT referenced by sequential ops after BN.
    if (consumers.size() != 1) {
      continue;
    }

    bool isBN;
    auto consumer = consumers.front();
    if (repr::nn::is<repr::BatchNormalization>(consumer)) {
      isBN = true;
    } else if (isOpType(consumer, "AffineChannel")) {
      isBN = false;
    } else {
      continue;
    }

    auto bnOrAffChNode = consumer;
    auto bn =
        isBN ? repr::nn::get<repr::BatchNormalization>(bnOrAffChNode) : nullptr;
    auto bnOrAffChOutput = repr::nn::getOutputs(bnOrAffChNode).front();

    auto convInputs = repr::nn::getInputs(convNode);
    if (convInputs.size() < 2) {
      LOG(WARNING) << "Invalid convolution input size";
      continue;
    }

    auto bnOrAffChInputs = repr::nn::getInputs(bnOrAffChNode);
    int numInputs = isBN ? 5 : 3;
    if (bnOrAffChInputs.size() < numInputs) {
      LOG(WARNING) << "Invalid input size: " << bnOrAffChInputs.size()
                   << ", expect " << numInputs;
      continue;
    }

    // When no bias, borrow BN bias
    if (convInputs.size() < 3) {
      no_bias = true;
      nn->dataFlow.createEdge(bnOrAffChInputs[2], convNode);
      convInputs = repr::nn::getInputs(convNode);
    }

#define EXPOSE_TENSOR_DATA(name, index, nodes, need_init)                  \
  itensor* name = nullptr;                                                 \
  itensor name##Tensor;                                                    \
  float* name##Data = nullptr;                                             \
  if (need_init) {                                                         \
    name = getMutableTensor<itensor>(getBlob(nodes[index], ws));           \
    if (name == nullptr) {                                                 \
      LOG(WARNING) << #name " not a IDEEP tensor";                         \
      continue;                                                            \
    }                                                                      \
    name##Tensor.resize(name->get_dims(), name->get_data_type());          \
    name##Tensor.feed_from(*name);                                         \
    CAFFE_ENFORCE(                                                         \
        name##Tensor.is_public_format(), #name " not with public format"); \
    name##Data = static_cast<float*>(name##Tensor.get_data_handle());      \
  }

    EXPOSE_TENSOR_DATA(filter, 1, convInputs, true);
    EXPOSE_TENSOR_DATA(biasConv, 2, convInputs, true);

    EXPOSE_TENSOR_DATA(scale, 1, bnOrAffChInputs, true);
    EXPOSE_TENSOR_DATA(biasBNOrAffCh, 2, bnOrAffChInputs, true);
    EXPOSE_TENSOR_DATA(mean, 3, bnOrAffChInputs, isBN);
    EXPOSE_TENSOR_DATA(variance, 4, bnOrAffChInputs, isBN);

#undef EXPOSE_TENSOR_DATA

    // Assume M{CHW,HWC}
    auto chwDim = filterTensor.get_dim(1) * filterTensor.get_dim(2) *
        filterTensor.get_dim(3);
    for (auto c = 0; c < filterTensor.get_dim(0); ++c) {
      float mean_val = 0;
      float variance_val = 1;
      if (isBN) {
        mean_val = meanData[c];
        variance_val = std::sqrt(varianceData[c] + bn->getEpsilon());
      }
      float coeff = scaleData[c] / variance_val;
      for (auto i = 0; i < chwDim; ++i) {
        filterData[c * chwDim + i] *= coeff;
      }

      if (no_bias) {
        biasConvData[c] = biasBNOrAffChData[c] - mean_val * coeff;
      } else {
        biasConvData[c] =
            biasBNOrAffChData[c] + (biasConvData[c] - mean_val) * coeff;
      }
    }

    filter->feed_from(filterTensor);
    biasConv->feed_from(biasConvTensor);
    nn->dataFlow.replaceNode(convOutput, bnOrAffChOutput);

    nn->dataFlow.deleteNode(bnOrAffChNode);
    nn->dataFlow.deleteNode(convOutput);

    return true;
  }
  return false;
}

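// Fuse an element-wise Sum/Add (or its Int8 variants) into the nearest
// preceding Conv, so the Conv accumulates into the other Sum input in place.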
bool fuseConvSum(repr::NNModule* nn, caffe2::Workspace* ws) {
  CAFFE_ENFORCE(cpuinfo_initialize(), "failed to initialize cpuinfo");
  // Assume the order of nodes from getMutableNodes conforms to
  // the original topo order of operators
  auto allNodes = nn->dataFlow.getMutableNodes();
  for (int i = allNodes.size() - 1; i > 0; i--) {
    auto sumNode = allNodes[i];
    if (!repr::nn::hasInputs(sumNode)) {
      continue;
    }

    // [Caution] on IDEEP device, only element-wise Add operator is
    // supported yet. It totally works as element-wise sum without scalar
    // broadcast.
    bool is_dnnlowp_sum = false;
    if (isOpType(sumNode, "Int8Sum") || isOpType(sumNode, "Int8Add") ||
        isOpType(sumNode, "Int8SumRelu") || isOpType(sumNode, "Int8AddRelu")) {
      is_dnnlowp_sum = true;
    } else if (!repr::nn::is<repr::Sum>(sumNode) && !isOpType(sumNode, "Add")) {
      continue;
    }

    auto sum = repr::nn::get<repr::NeuralNetOperator>(sumNode);
    if (!isOnIdeepDevice(*sum)) {
      LOG(WARNING) << "Not a IDEEP operator";
      continue;
    }

    auto sumInputs = repr::nn::getInputs(sumNode);
    if (sumInputs.size() != 2) {
      continue;
    }

    int sum_idx = i;
    repr::NNGraph::NodeRef convNode = nullptr;
    while (--i >= 0) {
      if (repr::nn::is<repr::NeuralNetOperator>(allNodes[i])) {
        // Find the nearest conv Op before Sum
        if (repr::nn::is<repr::Conv>(allNodes[i]) ||
            isOpType(allNodes[i], "Int8Conv")) {
          convNode = allNodes[i];
          break;
        }
      }
    }
    if (convNode == nullptr || isConvFusion(convNode, FUSION_MAX)) {
      continue;
    }
    int conv_idx = i;

    auto conv = repr::nn::get<repr::NeuralNetOperator>(convNode);
    if (!isOnIdeepDevice(*conv)) {
      LOG(WARNING) << "Not a IDEEP operator";
      continue;
    }

    auto group = 1;
    auto* convOp = getMutableOpDef(*conv);
    for (const auto& arg : convOp->arg()) {
      if (arg.name() == "group") {
        group = arg.i();
        break;
      }
    }
    if (group > 1 && !cpuinfo_has_x86_avx512f()) {
      LOG(WARNING) << "Not support conv sum fusion with grouped filter";
      continue;
    }

    auto convOutput = repr::nn::getOutputs(convNode).front();
    if (convOutput != sumInputs[0] && convOutput != sumInputs[1]) {
      continue;
    }
    repr::NNGraph::NodeRef sumInputX =
        (sumInputs[0] == convOutput ? sumInputs[1] : sumInputs[0]);
    CAFFE_ENFORCE(sumInputX != nullptr, "Invalid sum inputs");
    if (sumInputX->getInEdges().size() <= 0) {
      continue;
    }

    auto preNode = repr::nn::getProducer(sumInputX);
    if (preNode == nullptr || !repr::nn::is<repr::NeuralNetOperator>(preNode)) {
      LOG(WARNING) << "Can not fuse Conv Sum";
      continue;
    }
    int pre_idx = sum_idx - 1;
    while (pre_idx >= 0) {
      if (preNode == allNodes[pre_idx]) {
        break;
      }
      pre_idx--;
    }

    bool should_fuse = true;
    auto convInput = repr::nn::getInputs(convNode).front();
    for (int idx = conv_idx + 1; idx < allNodes.size() - 1; ++idx) {
      if (idx == sum_idx ||
          !repr::nn::is<repr::NeuralNetOperator>(allNodes[idx])) {
        continue;
      }

      auto checkNode = allNodes[idx];
      auto checkInputs = repr::nn::getInputs(checkNode);
      // Conv output should not be used by other ops after Conv node (except the
      // fused Sum) The other Sum input (sumInputX) should not be used by the
      // other ops after Sum node due to the Sum output is inplace with
      // sumInputX
      for (size_t input_idx = 0; input_idx < checkInputs.size(); ++input_idx) {
        if (convOutput == checkInputs[input_idx] ||
            (idx > sum_idx && sumInputX == checkInputs[input_idx])) {
          should_fuse = false;
          break;
        }
      }
      if (!should_fuse) {
        break;
      }

      // If fuse Conv with Sum, the Conv op will be pulled down between preNode
      // and Sum Check Conv input tensor buffer has been re-written by other ops
      // between Conv and preNode
      if (idx <= pre_idx) {
        auto checkOutputs = repr::nn::getOutputs(checkNode);
        for (size_t output_idx = 0; output_idx < checkOutputs.size();
             ++output_idx) {
          auto check_output_tensor =
              repr::nn::get<repr::Tensor>(checkOutputs[output_idx]);
          auto conv_input_tensor = repr::nn::get<repr::Tensor>(convInput);
          if (conv_input_tensor->getName() == check_output_tensor->getName()) {
            should_fuse = false;
            break;
          }
        }
      }
      if (!should_fuse) {
        break;
      }
    }
    if (!should_fuse) {
      continue;
    }

    nn->dataFlow.createEdge(sumInputX, convNode);
    auto newOutputName = repr::nn::get<repr::Tensor>(sumInputX)->getName() +
        "_fusion_fix_" + std::to_string(i);

    auto newInputTensor = std::make_unique<repr::Tensor>(newOutputName);
    auto newInput = nn->dataFlow.createNode(
        unique_dyn_cast<repr::NeuralNetData>(newInputTensor));

    nn->dataFlow.replaceNode(sumInputX, newInput);
    nn->dataFlow.deleteNode(sumInputX);

    auto newOutputTensor = std::make_unique<repr::Tensor>(newOutputName);
    auto newOutput = nn->dataFlow.createNode(
        unique_dyn_cast<repr::NeuralNetData>(newOutputTensor));

    auto sumOutput = repr::nn::getOutputs(sumNode).front();
    nn->dataFlow.replaceNode(sumOutput, newOutput);
    nn->dataFlow.createEdge(convNode, newOutput);

    if (!is_dnnlowp_sum) {
      resetConvForFusion(convNode, FUSION_CONV_SUM);
    } else {
      moveOpArg(ws, "Y_scale", sum, conv);
      moveOpArg(ws, "Y_zero_point", sum, conv);

      if (isOpType(sumNode, "Int8Sum") || isOpType(sumNode, "Int8Add")) {
        convOp->set_type("Int8ConvSum");
      } else if (
          isOpType(sumNode, "Int8SumRelu") ||
          isOpType(sumNode, "Int8AddRelu")) {
        convOp->set_type("Int8ConvSumRelu");
      } else {
        CAFFE_THROW("Unsupport operator in conv fusion");
      }
    }

    nn->dataFlow.deleteNode(sumNode);
    nn->dataFlow.deleteNode(sumOutput);
    nn->dataFlow.deleteNode(convOutput);
    return true;
  }
  return false;
}

bool fuseActivation(repr::NNModule* nn, caffe2::Workspace* ws) {
  // Conv+Relu fusion
  for (auto node_pair : repr::nn::dataIterator<repr::Conv>(nn->dataFlow)) {
    repr::NNGraph::NodeRef conv_node;
    repr::Conv* conv;
    std::tie(conv, conv_node) = node_pair;

    // Check topological feasibility
    auto conv_outputs = repr::nn::getOutputs(conv_node);
    if (conv_outputs.size() != 1) {
      continue;
    }
    auto conv_output = conv_outputs.front();

    auto consumers = repr::nn::getConsumers(conv_output);
    if (consumers.size() != 1) {
      continue;
    }
    if (!repr::nn::is<repr::Relu>(consumers.front())) {
      continue;
    }
    auto relu_node = consumers.front();
    auto relu = repr::nn::get<repr::Relu>(relu_node);

    auto relu_outputs = repr::nn::getOutputs(relu_node);
    if (relu_outputs.size() != 1) {
      continue;
    }

    // Check feasibility with application specific logic
    if (!isOnIdeepDevice(*conv)) {
      continue;
    }

    // Ready to fuse
    auto relu_output = relu_outputs.front();
    auto output_tensor = repr::nn::get<repr::Tensor>(relu_output);
    auto output_node = relu_output;
    auto input_tensor =
        repr::nn::get<repr::Tensor>(repr::nn::getInputs(conv_node).front());

    if (isConvFusion(conv_node, FUSION_CONV_SUM)) {
      nn->dataFlow.replaceNode(relu_output, conv_output);
      nn->dataFlow.deleteNode(relu_node);
      nn->dataFlow.deleteNode(relu_output);
    } else {
      // Conv cannot be in-place
      if (output_tensor->getName() != input_tensor->getName()) {
        nn->dataFlow.replaceNode(conv_output, relu_output);
        nn->dataFlow.deleteNode(relu_node);
        nn->dataFlow.deleteNode(conv_output);
      } else {
        nn->dataFlow.replaceNode(relu_output, conv_output);
        output_tensor = repr::nn::get<repr::Tensor>(conv_output);
        output_node = conv_output;
        nn->dataFlow.deleteNode(relu_node);
        nn->dataFlow.deleteNode(relu_output);
      }

      // We may have accidentally made the next op in-place
      // In future iterations of transformations this won't be an issue,
      // but current caffe2 predictor usage requires things like
      // external_input and output to be unchanged.
      bool rectify_inplace = false;
      for (auto& consumer : repr::nn::getConsumers(output_node)) {
        for (auto& consumer_output : repr::nn::getOutputs(consumer)) {
          auto co_name =
              repr::nn::get<repr::Tensor>(consumer_output)->getName();
          if (co_name == output_tensor->getName()) {
            rectify_inplace = true;
          }
        }
      }
      if (rectify_inplace) {
        auto new_output = nn->dataFlow.createNode(make_unique<repr::Tensor>(
            output_tensor->getName() + "_fusion_fix"));
        nn->dataFlow.replaceNode(output_node, new_output);
      }
    }

    resetConvForFusion(conv_node, FUSION_CONV_RELU);
    return true;
  }
  return false;
}

bool enforceFusionInplace(repr::NNModule* nn, caffe2::Workspace* ws) {
  // For fusions of Conv+Sum or Conv+Sum+ReLU, the last input and output must
  // be inplaced. To enforce inplace, here to re-check whole graph and correct
  // the ConvFusion Ops.
  auto allNodes = nn->dataFlow.getMutableNodes();
  for (int i = allNodes.size() - 1; i > 0; i--) {
    auto convNode = allNodes[i];
    if (convNode == nullptr ||
        !repr::nn::is<repr::NeuralNetOperator>(convNode)) {
      continue;
    }

    auto conv = repr::nn::get<repr::NeuralNetOperator>(convNode);
    if (!isOnIdeepDevice(*conv)) {
      LOG(WARNING) << "Not a IDEEP operator";
      continue;
    }

    if (repr::nn::is<repr::Conv>(convNode)) {
      if (!isConvFusion(convNode, FUSION_CONV_SUM) &&
          !isConvFusion(convNode, FUSION_CONV_SUM_RELU))
        continue;
    } else if (
        !isOpType(convNode, "Int8ConvSum") &&
        !isOpType(convNode, "Int8ConvSumRelu")) {
      continue;
    }

    auto convInput = repr::nn::getInputs(convNode).back();
    auto inputName = repr::nn::get<repr::Tensor>(convInput)->getName();
    auto convOutput = repr::nn::getOutputs(convNode).front();
    auto outputName = repr::nn::get<repr::Tensor>(convOutput)->getName();
    if (inputName == outputName) {
      continue;
    }

    auto consumer = repr::nn::getConsumers(convInput).back();
    if (consumer != convNode) {
      LOG(ERROR) << "Can not enforce to inplace for fusion";
      return false;
    }

    auto newOutputTensor = std::make_unique<repr::Tensor>(inputName);
    auto newOutput = nn->dataFlow.createNode(
        unique_dyn_cast<repr::NeuralNetData>(newOutputTensor));
    nn->dataFlow.replaceNode(convOutput, newOutput);
    nn->dataFlow.deleteNode(convOutput);

    return true;
  }
  return false;
}

bool fuseOrderSwitchToQuantizeOp(repr::NNModule* nn, caffe2::Workspace* ws) {
  // In INT8 module, the quantize/dequantize op always appears
  // along with corresponding order switch op, which aims to switch
  // between INT8 computation domain and others.
  // Here we assume they always obey below combination and order:
  // NCHW2NHWC followed by Int8Quantize, or Int8Dequantize followed by NHWC2NCHW
  // On iDEEP, there is chance to fuse the order switch op into the
  // quantize/dequantize op, in order to improve the module performance.
  auto allNodes = nn->dataFlow.getMutableNodes();
  for (int i = 0; i < allNodes.size(); ++i) {
    auto osNode = allNodes[i];
    if (osNode == nullptr || !repr::nn::is<repr::NeuralNetOperator>(osNode)) {
      continue;
    }

    if (isOpType(osNode, "NCHW2NHWC")) {
      auto output = repr::nn::getOutputs(osNode).front();
      auto consumers = repr::nn::getConsumers(output);
      if (consumers.size() != 1) {
        continue;
      }

      auto seqNode = consumers.front();
      if (!isOpType(seqNode, "Int8Quantize")) {
        continue;
      }

      auto seq = repr::nn::get<repr::NeuralNetOperator>(seqNode);
      removeArg(*seq, "output_order");

      auto* seqOp = getMutableOpDef(*seq);
      auto* arg = seqOp->add_arg();
      arg->set_name("output_order");
      arg->set_i(static_cast<int64_t>(iformat::nhwc));

      auto input = repr::nn::getInputs(osNode).front();
      nn->dataFlow.replaceNode(output, input);

      nn->dataFlow.deleteNode(osNode);
      nn->dataFlow.deleteNode(output);
      return true;
    } else if (isOpType(osNode, "NHWC2NCHW")) {
      auto input = repr::nn::getInputs(osNode).front();
      if (input->getInEdges().size() <= 0) {
        continue;
      }

      auto preNode = repr::nn::getProducer(input);
      if (!isOpType(preNode, "Int8Dequantize")) {
        continue;
      }

      auto pre = repr::nn::get<repr::NeuralNetOperator>(preNode);
      removeArg(*pre, "output_order");

      auto* preOp = getMutableOpDef(*pre);
      auto* arg = preOp->add_arg();
      arg->set_name("output_order");
      arg->set_i(static_cast<int64_t>(iformat::nchw));

      auto output = repr::nn::getOutputs(osNode).front();
      nn->dataFlow.replaceNode(input, output);

      nn->dataFlow.deleteNode(osNode);
      nn->dataFlow.deleteNode(input);
      return true;
    }
  }
  return false;
}

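// Drop a layout/quantization conversion op (NCHW2NHWC, NHWC2NCHW,
// Int8Quantize, Int8Dequantize) when its single consumer is in op_list and
// can therefore handle the unconverted input by itself.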
bool fusePreConvertOp(repr::NNModule* nn, caffe2::Workspace* ws) {
  // 1. Int8Sum has been fallbacked to FP32 in current impl
  //    It can handle inputs with diff format and data type
  // 2. FC is able to convert input format and data type by itself
  // 3. The fallback wrapper can handle the conversion of format and data type
  static vector<string> op_list = {
      "FC",
      "Python",
      "Softmax",
      "Sigmoid",
      "RoIAlign",
      "UpsampleNearest",
      "BatchPermutation",
      "Int8Sum",
      "Int8SumRelu",
  };

  auto allNodes = nn->dataFlow.getMutableNodes();
  for (int i = 0; i < allNodes.size(); ++i) {
    auto opNode = allNodes[i];
    if (opNode == nullptr || !repr::nn::is<repr::NeuralNetOperator>(opNode)) {
      continue;
    }

    if (!isOpType(opNode, "NCHW2NHWC") && !isOpType(opNode, "NHWC2NCHW") &&
        !isOpType(opNode, "Int8Quantize") &&
        !isOpType(opNode, "Int8Dequantize")) {
      continue;
    }

    auto op = repr::nn::get<repr::NeuralNetOperator>(opNode);
    if (!isOnIdeepDevice(*op)) {
      LOG(WARNING) << "Not a IDEEP operator";
      continue;
    }

    auto output = repr::nn::getOutputs(opNode).front();
    auto consumers = repr::nn::getConsumers(output);
    if (consumers.size() != 1) {
      continue;
    }

    bool is_op_found = false;
    auto seqNode = consumers.front();
    for (int j = 0; j < op_list.size(); j++) {
      if (isOpType(seqNode, op_list[j])) {
        is_op_found = true;
        break;
      }
    }
    if (!is_op_found) {
      continue;
    }

    auto seqOp = repr::nn::get<repr::NeuralNetOperator>(seqNode);
    if (!isOnIdeepDevice(*seqOp)) {
      LOG(WARNING) << "Not a IDEEP operator";
      continue;
    }

    auto input = repr::nn::getInputs(opNode).front();

    if (isOpType(opNode, "Int8Dequantize") &&
        repr::nn::hasSingleOutputAndConsumer(opNode)) {
      auto preNode = repr::nn::getProducer(input);
      if (isOpType(preNode, "Int8FC") &&
          repr::nn::hasSingleOutputAndConsumer(preNode)) {
        auto predOp = repr::nn::get<repr::NeuralNetOperator>(preNode);
        removeArg(*predOp, "Y_scale");
        removeArg(*predOp, "Y_zero_point");
      }
    }

    nn->dataFlow.replaceNode(output, input);

    nn->dataFlow.deleteNode(opNode);
    nn->dataFlow.deleteNode(output);
    return true;
  }
  return false;
}

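// Force training_mode=0 on conv, FC and pooling ops (FP32 and Int8) so that
// they take their inference code path.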
void setPoolingInferenceMode(repr::NNModule* nn) {
  auto setTrainingMode = [](repr::NeuralNetOperator& pool) {
    if (!isOnIdeepDevice(pool)) {
      LOG(WARNING) << "Not a IDEEP operator";
      return;
    }
    auto* op = getMutableOpDef(pool);
    bool found_training_mode = false;
    for (auto& arg : *op->mutable_arg()) {
      if (arg.name() == "training_mode") {
        arg.set_i(0);
        found_training_mode = true;
        break;
      }
    }
    if (!found_training_mode) {
      auto* arg = op->add_arg();
      arg->set_name("training_mode");
      arg->set_i(0);
    }
  };

  auto allNodes = nn->dataFlow.getMutableNodes();
  for (int i = 0; i < allNodes.size(); ++i) {
    auto poolNode = allNodes[i];
    if (poolNode == nullptr ||
        !repr::nn::is<repr::NeuralNetOperator>(poolNode)) {
      continue;
    }

    if (isOpType(poolNode, "FC") || isOpType(poolNode, "Conv") ||
        isOpType(poolNode, "ConvFusion") || isOpType(poolNode, "MaxPool") ||
        isOpType(poolNode, "AveragePool") || isOpType(poolNode, "Int8FC") ||
        isOpType(poolNode, "Int8Conv") || isOpType(poolNode, "Int8ConvRelu") ||
        isOpType(poolNode, "Int8ConvSum") ||
        isOpType(poolNode, "Int8ConvSumRelu") ||
        isOpType(poolNode, "Int8MaxPool") ||
        isOpType(poolNode, "Int8AveragePool")) {
      auto pool = repr::nn::get<repr::NeuralNetOperator>(poolNode);
      setTrainingMode(*pool);
    }
  }
}

// Pre-convert filters format to expected one here
// in order to avoid boring conversions during computations
void preConvertFiltersFormat(repr::NNModule* nn, caffe2::Workspace* ws) {
  for (auto& node : nn->dataFlow.getMutableNodes()) {
    if (!repr::nn::is<repr::ConvTranspose>(node) &&
        !repr::nn::is<repr::Conv>(node) && !repr::nn::is<repr::FC>(node)) {
      continue;
    }

    auto* nnOp = repr::nn::get<repr::NeuralNetOperator>(node);
    if (!isOnIdeepDevice(*nnOp)) {
      LOG(INFO) << "Not a IDEEP operator";
      continue;
    }

    auto inputs = repr::nn::getInputs(node);
    if (inputs.size() < 2) {
      LOG(WARNING) << "Invalid input size";
      continue;
    }

    auto* filterBlob = getBlob(inputs[1], ws);
    auto* filter = getMutableTensor<itensor>(filterBlob);
    if (filter == nullptr) {
      continue;
    }

    itensor::descriptor expectedDesc;
    if (repr::nn::is<repr::ConvTranspose>(node)) {
      if (filter->get_desc().is_iohw())
        continue;
      auto convTranspose = repr::nn::get<repr::ConvTranspose>(node);
      auto initValue = [](vector<int>& v, vector<int> i) {
        if (v.empty())
          v = i;
      };
      auto strides = convTranspose->getStrides();
      initValue(strides, {1, 1});
      auto pads = convTranspose->getPads();
      initValue(pads, {0, 0, 0, 0});
      auto* op = getMutableOpDef(*convTranspose);
      auto aalgorithm = ialgo::deconvolution_direct;
      auto dataType = filter->get_data_type();
      ideep::tensor::dims filter_dims_mkldnn{filter->get_dim(1),
                                             filter->get_dim(0),
                                             filter->get_dim(2),
                                             filter->get_dim(3)};
      expectedDesc =
          ideep::convolution_transpose_forward::expected_weights_desc(
              filter_dims_mkldnn,
              dataType,
              {strides.begin(), strides.end()},
              {pads[0], pads[1]},
              {pads[2], pads[3]});

      if (filter->get_descriptor() != expectedDesc) {
        itensor newFilter;
        newFilter.init(expectedDesc);
        newFilter.feed_from(*filter);
        filterBlob->Reset<itensor>(new itensor(std::move(newFilter)));
      }
    } else if (repr::nn::is<repr::Conv>(node)) {
      auto conv = repr::nn::get<repr::Conv>(node);
      auto initValue = [](vector<int>& v, vector<int> i) {
        if (v.empty())
          v = i;
      };
      auto strides = conv->getStrides();
      initValue(strides, {1, 1});
      auto pads = conv->getPads();
      initValue(pads, {0, 0, 0, 0});
      auto dilations = conv->getDilations();
      initValue(dilations, {1, 1});

      auto* op = getMutableOpDef(*conv);
      auto aalgorithm = ialgo::convolution_direct;
      for (auto& arg : *op->mutable_arg()) {
        if ((arg.name() == "conv_algorithm") &&
            (arg.i() == CONV_ALGORITHM_WINOGRAD)) {
          aalgorithm = ialgo::convolution_winograd;
        }
      }

      expectedDesc = ideep::convolution_forward::expected_weights_desc(
          filter->get_dims(),
          filter->get_data_type(),
          {strides.begin(), strides.end()},
          {pads[0], pads[1]},
          {pads[2], pads[3]},
          {dilations.begin(), dilations.end()},
          conv->getGroup(),
          aalgorithm);

      if (filter->get_descriptor() != expectedDesc) {
        itensor newFilter;
        newFilter.init(expectedDesc);
        newFilter.feed_from(*filter);
        filterBlob->Reset<itensor>(new itensor(std::move(newFilter)));
      }
      // convert weights for FC
    } else if (repr::nn::is<repr::FC>(node)) {
      auto fc = repr::nn::get<repr::FC>(node);
      auto axis_w = fc->getAxisW();
      if (axis_w != 1) {
        auto f_dims = filter->get_dims();
        auto f_dim0 = std::accumulate(
            f_dims.begin(),
            f_dims.begin() + axis_w,
            1,
            std::multiplies<itensor::dim_t>());
        auto f_dim1 = std::accumulate(
            f_dims.begin() + axis_w,
            f_dims.end(),
            1,
            std::multiplies<itensor::dim_t>());
        filter->reshape({f_dim0, f_dim1});
      }

      expectedDesc = ideep::inner_product_forward::expected_weights_desc(
          filter->get_dims());

      if (filter->get_descriptor() != expectedDesc) {
        itensor newFilter;
        newFilter.init(expectedDesc);
        newFilter.feed_from(*filter);
        filterBlob->Reset<itensor>(new itensor(std::move(newFilter)));
      }
    }
  }
}

// Fusers for ideep to parse the graph and apply operator fusion
using Fuser = bool (*)(repr::NNModule* nn, caffe2::Workspace* ws);
static Fuser fusers[] = {
    removeStopGradientForInference,
    fuseConvBNAndAffCh,
    fuseConvSum,
    fuseActivation,
    enforceFusionInplace,
    fuseOrderSwitchToQuantizeOp,
    fusePreConvertOp,
};

void OptimizeForMkldnn(
    repr::NNModule* nn,
    caffe2::Workspace* ws,
    bool training_mode) {
  if (training_mode) {
    preConvertFiltersFormat(nn, ws);
    return;
  }

  for (auto fuser : fusers) {
    while (fuser(nn, ws)) {
    }
  }

  setPoolingInferenceMode(nn);
}

#endif // CAFFE2_USE_MKLDNN

} // namespace opt
} // namespace caffe2