#include "caffe2/opt/fakefp16_transform.h"
|
|
|
|
#include "caffe2/opt/glow_net_transform.h"
|
|
#include "caffe2/utils/proto_utils.h"
|
|
|
|
C10_DEFINE_bool(
    fake_fp16_conversion_use_fp16_acc,
    false,
    "Whether to enable fp16 accumulation for FC / BatchMatMul for fakefp16 "
    "operators.");

C10_DEFINE_bool(
    fake_fp16_conversion_use_nnpi,
    false,
    "Whether to simulate NNPI behavior for fakefp16 operators.");

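// Both knobs are ordinary C10 command-line flags, so any binary linking this
// file can toggle them at startup. Illustrative invocation (the binary name
// is hypothetical):
//
//   ./caffe2_predictor --fake_fp16_conversion_use_fp16_acc=true \
//                      --fake_fp16_conversion_use_nnpi=true
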
namespace caffe2 {
namespace opt {

// Maps vanilla operator types to their fake-fp16 emulation counterparts.
// `use_fp16_acc` switches FC / BatchMatMul to fp16 (instead of fp32)
// accumulation; `use_nnpi` selects NNPI-specific Sigmoid / Tanh emulation.
std::unordered_map<std::string, std::string> getFakeFp16OpMapping(
    bool use_fp16_acc,
    bool use_nnpi) {
  std::unordered_map<std::string, std::string> fake_fp16_op_conversion_map = {
      {"FC", "Fp16FCAcc32NNPI"},
      {"Int8FC", "Int8FCFakeAcc32NNPI"},
      {"Int8Quantize", "Int8QuantizeNNPI"},
      {"Int8Dequantize", "Int8DequantizeNNPI"},
      {"LayerNorm", "LayerNormFakeFP16NNPI"},
      {"FbFCPacked", "Fp16FCAcc32NNPI"},
      {"Logit", "LogitFakeFp16NNPI"},
      {"SparseLengthsSum", "SparseLengthsSumFakeFP16AccFP16"},
      {"SparseLengthsWeightedSum", "SparseLengthsWeightedSumFakeFP16AccFP16"},
      {"SparseLengthsMean", "SparseLengthsMeanFakeFP16AccFP16"},
      {"SparseLengthsSumFused4BitRowwise",
       "SparseLengthsSumFused4BitRowwiseFakeFP16NNPI"},
      {"SparseLengthsWeightedSumFused4BitRowwise",
       "SparseLengthsWeightedSumFused4BitRowwiseFakeFP16NNPI"},
      {"SparseLengthsSumFused8BitRowwise",
       "SparseLengthsSumFused8BitRowwiseFakeFP16NNPI"},
      {"SparseLengthsWeightedSumFused8BitRowwise",
       "SparseLengthsWeightedSumFused8BitRowwiseFakeFP16NNPI"},
      {"SparseLengthsMeanFused8BitRowwise",
       "SparseLengthsMeanFused8BitRowwiseFakeFP16AccFP16"},
      {"MatMul", "BatchMatMulFP16Acc32Fake"},
      {"BatchMatMul", "BatchMatMulFP16Acc32Fake"},
      {"Sigmoid", "SigmoidFakeFp16"},
      {"SpatialBN", "SpatialBNFakeFp16NNPI"},
      {"Swish", "SwishFakeFp16NNPI"},
      {"Tanh", "TanhFakeFp16"},
      {"Relu", "ReluFakeFp16"},
      {"Add", "AddFakeFp16"},
      {"Sub", "SubFakeFp16"},
      {"Mul", "MulFakeFp16"},
      {"Div", "DivFakeFp16"},
      {"Sum", "SumFakeFp16"},
      {"Sqr", "SqrFakeFp16"},
      {"LengthsSum", "LengthsSumFakeFp16"}};
  if (use_fp16_acc) {
    fake_fp16_op_conversion_map["FC"] = "Fp16FCAcc16NNPI";
    fake_fp16_op_conversion_map["FbFCPacked"] = "Fp16FCAcc16NNPI";
    fake_fp16_op_conversion_map["BatchMatMul"] = "BatchMatMulFP16Acc16Fake";
    fake_fp16_op_conversion_map["MatMul"] = "BatchMatMulFP16Acc16Fake";
  }
  if (use_nnpi) {
    fake_fp16_op_conversion_map["Sigmoid"] = "SigmoidFakeFp16NNPI";
    fake_fp16_op_conversion_map["Tanh"] = "TanhFakeFp16NNPI";
  }
  return fake_fp16_op_conversion_map;
}

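// Illustrative lookups: with fp16 accumulation enabled the FC mapping
// switches to the Acc16 kernel, otherwise it stays on Acc32:
//
//   getFakeFp16OpMapping(/*use_fp16_acc=*/true, /*use_nnpi=*/false)["FC"]
//       == "Fp16FCAcc16NNPI"
//   getFakeFp16OpMapping(/*use_fp16_acc=*/false, /*use_nnpi=*/false)["FC"]
//       == "Fp16FCAcc32NNPI"
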
// Returns pointers to every operator in `net` that consumes blob `input`.
std::vector<OperatorDef*> findMutableOperatorByInput(
    NetDef* net,
    const std::string& input) {
  std::vector<OperatorDef*> ops;

  for (auto& op : *net->mutable_op()) {
    for (const auto& i : op.input()) {
      if (input == i) {
        ops.push_back(&op);
      }
    }
  }
  return ops;
}

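// Note: an operator that consumes `input` through more than one input slot is
// pushed once per matching slot, so the size() != 1 checks below also reject
// that case.
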
// Folds the affine part of layer norm into the op itself: rewrites the
// pattern LayerNormFakeFP16NNPI -> MulFakeFp16 (scale) -> AddFakeFp16 (bias)
// into a single LayerNormFakeFP16NNPI with scale and bias as extra inputs.
void fakeFp16FoldLayerNorm(NetDef* net) {
  for (auto& op : *net->mutable_op()) {
    if (op.type() == "LayerNormFakeFP16NNPI") {
      LOG(INFO) << "Attempting to fuse LayerNormFakeFP16NNPI at "
                << ArgumentHelper::GetSingleArgument<OperatorDef, int>(
                       op, "net_pos", -1);
      if (op.input().size() != 1) {
        LOG(INFO) << "input count isn't 1, skipping";
        continue;
      }

      const std::string& ln_output = op.output(0);
      auto next_ops = findMutableOperatorByInput(net, ln_output);

      if (next_ops.size() != 1 || next_ops[0]->type() != "MulFakeFp16") {
        LOG(INFO) << "next op isn't MulFakeFp16, skipping";
        continue;
      }

      auto* mul_op = next_ops[0];

      auto next_next_ops = findMutableOperatorByInput(net, mul_op->output(0));

      if (next_next_ops.size() != 1 ||
          next_next_ops[0]->type() != "AddFakeFp16") {
        LOG(INFO) << "next op isn't AddFakeFp16, skipping";
        continue;
      }

      auto* add_op = next_next_ops[0];

      // Pull the Mul's scale and the Add's bias in as extra inputs, and write
      // directly to the Add's output.
      *(op.mutable_input()->Add()) = mul_op->input(1);
      *(op.mutable_input()->Add()) = add_op->input(1);
      *op.mutable_output(0) = add_op->output(0);

      // Mark the now-redundant ops; fakeFp16FuseOps() erases them afterwards.
      mul_op->set_type("delete_me_optimized_away");
      add_op->set_type("delete_me_optimized_away");

      LOG(INFO) << "Fused LayerNormFakeFP16NNPI";
    }
  }
}

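// Illustrative before/after for the fold above (blob names hypothetical):
//
//   before: X                -> LayerNormFakeFP16NNPI -> Y
//           (Y, gamma)       -> MulFakeFp16           -> Z
//           (Z, beta)        -> AddFakeFp16           -> W
//   after:  (X, gamma, beta) -> LayerNormFakeFP16NNPI -> W
//
// The Mul and Add are retyped to "delete_me_optimized_away" and erased later
// in fakeFp16FuseOps().
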
// Fuses LayerNormFakeFP16NNPI -> Int8QuantizeNNPI into a single
// LayerNormInt8QuantizeFakeNNPI op, copying the quantization parameters.
void fakeFp16FoldLayerNormQuant(NetDef* net) {
  for (auto& op : *net->mutable_op()) {
    if (op.type() == "LayerNormFakeFP16NNPI") {
      auto layernormNetPos =
          ArgumentHelper::GetSingleArgument<OperatorDef, int>(
              op, "net_pos", -1);
      LOG(INFO) << "Attempting to fuse LayerNormFakeFP16NNPI w Quant at "
                << layernormNetPos;
      if (op.input().size() != 1) {
        LOG(INFO) << "input count isn't 1, is " << op.input().size()
                  << ", skipping";
        continue;
      }

      const std::string& ln_output = op.output(0);
      auto next_ops = findMutableOperatorByInput(net, ln_output);

      if (next_ops.size() != 1 || next_ops[0]->type() != "Int8QuantizeNNPI") {
        LOG(INFO) << "next op isn't Int8QuantizeNNPI, skipping";
        continue;
      }

      auto* quantOp = next_ops[0];

      if (quantOp->output().size() != 1) {
        LOG(INFO) << "more than one output for quant, skipping";
        continue;
      }

      op.set_type("LayerNormInt8QuantizeFakeNNPI");

      // Take over the quantize op's output and its quantization parameters.
      *op.mutable_output(0) = quantOp->output(0);
      op.add_arg()->CopyFrom(MakeArgument(
          "Y_scale",
          ArgumentHelper::GetSingleArgument<OperatorDef, float>(
              *quantOp, "Y_scale", -1)));
      op.add_arg()->CopyFrom(MakeArgument(
          "Y_zero_point",
          ArgumentHelper::GetSingleArgument<OperatorDef, int>(
              *quantOp, "Y_zero_point", -1)));

      auto quantNetPos = ArgumentHelper::GetSingleArgument<OperatorDef, int>(
          *quantOp, "net_pos", -1);

      quantOp->set_type("delete_me_optimized_away");

      LOG(INFO) << "Fused LayerNormFakeFP16NNPI w Quant at " << layernormNetPos
                << " " << quantNetPos;
    }
  }
}

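// Illustrative before/after (blob names hypothetical):
//
//   before: X -> LayerNormFakeFP16NNPI -> Y -> Int8QuantizeNNPI -> Yq
//   after:  X -> LayerNormInt8QuantizeFakeNNPI -> Yq
//
// with Y_scale / Y_zero_point copied over from the quantize op.
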
// Finds the sequence Int8DequantizeNNPI -> SwishFakeFp16NNPI ->
// Int8QuantizeNNPI and replaces it with a single SwishFakeInt8NNPI op.
void fakeFp16FoldSwish(NetDef* net) {
  for (auto& op : *net->mutable_op()) {
    if (op.type() == "Int8DequantizeNNPI") {
      auto deq_net_pos = ArgumentHelper::GetSingleArgument<OperatorDef, int>(
          op, "net_pos", -1);

      LOG(INFO) << "Attempting swish fusion at " << deq_net_pos;

      if (op.output().size() != 1) {
        LOG(INFO) << "more than one output for deq, skipping";
        continue;
      }

      const std::string& deqOutput = op.output(0);
      auto next_ops = findMutableOperatorByInput(net, deqOutput);

      // Note: don't dereference next_ops[0] before checking the size; the
      // vector may be empty.
      if (next_ops.size() != 1 || next_ops[0]->type() != "SwishFakeFp16NNPI") {
        LOG(INFO) << "skipping, next op isn't a single SwishFakeFp16NNPI";
        continue;
      }

      auto* swishOp = next_ops[0];

      if (swishOp->output().size() != 1) {
        LOG(INFO) << "more than one output for swish, skipping";
        continue;
      }

      auto next_next_ops = findMutableOperatorByInput(net, swishOp->output(0));

      if (next_next_ops.size() != 1 ||
          next_next_ops[0]->type() != "Int8QuantizeNNPI") {
        LOG(INFO) << "skipping, next op isn't a single Int8QuantizeNNPI";
        continue;
      }

      auto* quantOp = next_next_ops[0];

      // Retype the dequantize op in place and splice out the other two.
      op.set_type("SwishFakeInt8NNPI");
      *op.mutable_output(0) = quantOp->output(0);
      op.add_arg()->CopyFrom(MakeArgument(
          "Y_scale",
          ArgumentHelper::GetSingleArgument<OperatorDef, float>(
              *quantOp, "Y_scale", -1)));
      op.add_arg()->CopyFrom(MakeArgument(
          "Y_zero_point",
          ArgumentHelper::GetSingleArgument<OperatorDef, int>(
              *quantOp, "Y_zero_point", -1)));

      auto swish_net_pos = ArgumentHelper::GetSingleArgument<OperatorDef, int>(
          *swishOp, "net_pos", -1);
      auto quant_net_pos = ArgumentHelper::GetSingleArgument<OperatorDef, int>(
          *quantOp, "net_pos", -1);

      swishOp->set_type("delete_me_optimized_away");
      quantOp->set_type("delete_me_optimized_away");

      LOG(INFO) << "Fusing swish at " << deq_net_pos << ", " << swish_net_pos
                << ", " << quant_net_pos;
    }
  }
}

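// Illustrative before/after (blob names hypothetical):
//
//   before: Xq -> Int8DequantizeNNPI -> X -> SwishFakeFp16NNPI -> Y
//           Y  -> Int8QuantizeNNPI   -> Yq
//   after:  Xq -> SwishFakeInt8NNPI  -> Yq
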
// Finds the sequence tanh->quant and replaces it: TanhFakeFp16NNPI ->
// Int8QuantizeNNPI becomes a single TanhQuantFakeFp16NNPI op.
void fakeFp16FoldTanhQuant(NetDef* net) {
  for (auto& op : *net->mutable_op()) {
    if (op.type() == "TanhFakeFp16NNPI") {
      auto tanh_net_pos = ArgumentHelper::GetSingleArgument<OperatorDef, int>(
          op, "net_pos", -1);

      LOG(INFO) << "Attempting tanh fusion at " << tanh_net_pos;

      if (op.output().size() != 1) {
        LOG(INFO) << "more than one output for tanh, skipping";
        continue;
      }

      const std::string& tanhOutput = op.output(0);
      auto next_ops = findMutableOperatorByInput(net, tanhOutput);

      // Note: don't dereference next_ops[0] before checking the size; the
      // vector may be empty.
      if (next_ops.size() != 1 || next_ops[0]->type() != "Int8QuantizeNNPI") {
        LOG(INFO) << "skipping, next op isn't a single Int8QuantizeNNPI";
        continue;
      }

      auto* quantOp = next_ops[0];

      if (quantOp->output().size() != 1) {
        LOG(INFO) << "more than one output for quant, skipping";
        continue;
      }

      op.set_type("TanhQuantFakeFp16NNPI");
      *op.mutable_output(0) = quantOp->output(0);
      op.add_arg()->CopyFrom(MakeArgument(
          "Y_scale",
          ArgumentHelper::GetSingleArgument<OperatorDef, float>(
              *quantOp, "Y_scale", -1)));
      op.add_arg()->CopyFrom(MakeArgument(
          "Y_zero_point",
          ArgumentHelper::GetSingleArgument<OperatorDef, int>(
              *quantOp, "Y_zero_point", -1)));

      auto quant_net_pos = ArgumentHelper::GetSingleArgument<OperatorDef, int>(
          *quantOp, "net_pos", -1);

      quantOp->set_type("delete_me_optimized_away");

      LOG(INFO) << "Fusing tanh and quant at " << tanh_net_pos << ", "
                << quant_net_pos;
    }
  }
}

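// Illustrative before/after (blob names hypothetical):
//
//   before: X -> TanhFakeFp16NNPI -> Y -> Int8QuantizeNNPI -> Yq
//   after:  X -> TanhQuantFakeFp16NNPI -> Yq
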
void fakeFp16FuseOps(NetDef* net) {
  LOG(INFO) << "Running Fp16 Fusion";

  // We should fuse the groups of bigger operators first
  fakeFp16FoldLayerNorm(net);
  fakeFp16FoldSwish(net);
  fakeFp16FoldTanhQuant(net);
  fakeFp16FoldLayerNormQuant(net);

  // Erase everything the passes above marked for deletion. erase() returns an
  // iterator to the element after the erased one, so the loop stays valid.
  auto iter = net->mutable_op()->begin();
  while (iter != net->mutable_op()->end()) {
    if (iter->type() == "delete_me_optimized_away") {
      iter = net->mutable_op()->erase(iter);
    } else {
      ++iter;
    }
  }
}

void fakeFp16Transform(NetDef* net) {
  static const std::unordered_map<std::string, std::string>
      kFakeFp16OpConversionMap = getFakeFp16OpMapping(
          FLAGS_fake_fp16_conversion_use_fp16_acc,
          FLAGS_fake_fp16_conversion_use_nnpi);

  auto blocklist_pos = glow::ParseNetPositionList(FLAGS_onnxifi_blacklist);
  auto blocklist_type = glow::ParseBlockListOps(FLAGS_onnxifi_blacklist_ops);

  // A hack to only do the fakefp16 transformation for operators which will be
  // lowered to ONNXIFI: everything up to and including the last Clip op is
  // left untouched.
  // TODO(yingz): Use more deterministic logic to figure out which operators
  // can be lowered to ONNXIFI instead.
  int last_clip_idx = -1;
  for (int i = 0; i < net->op().size(); ++i) {
    const auto& op = net->op(i);
    if (op.type() == "Clip") {
      last_clip_idx = i;
    }
  }
  for (int i = 0; i < net->op().size(); ++i) {
    if (i <= last_clip_idx) {
      continue;
    }
    auto* op = net->mutable_op(i);
    auto net_pos =
        ArgumentHelper::GetSingleArgument<OperatorDef, int>(*op, "net_pos", -1);
    if (blocklist_pos.count(net_pos) || blocklist_type.count(op->type())) {
      continue;
    }
    auto it = kFakeFp16OpConversionMap.find(op->type());
    if (it != kFakeFp16OpConversionMap.end()) {
      op->set_type(it->second);
    }
  }

  fakeFp16FuseOps(net);
}

} // namespace opt
} // namespace caffe2

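// Example driver (an illustrative sketch, not part of this file): applying
// the transform to a serialized NetDef. ReadProtoFromFile comes from the
// proto_utils.h header included above; the file name is hypothetical.
//
//   caffe2::NetDef net;
//   CAFFE_ENFORCE(caffe2::ReadProtoFromFile("predict_net.pb", &net));
//   caffe2::opt::fakeFp16Transform(&net);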