[quant][graphmode][fix] dequantize propagation for {add/mul}_scalar (#40596)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/40596

Previously the fusion patterns for {add/mul}_scalar were inconsistent: the op pattern produced a non-quantized tensor while the op replacement graph produced a quantized tensor.

Test Plan: Imported from OSS

Differential Revision: D22251072

fbshipit-source-id: e16eb92cf6611578cca1ed8ebde961f8d0610137
This commit is contained in:
parent
547ea787ff
commit
e3a97688cc
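
Below is a minimal Python sketch (not part of the commit) of the behavior the summary refers to: quantized::add_scalar derives its output quantization parameters from its input and returns a quantized tensor, so the replacement graph ends in a quantized value while the old op pattern ended in a non-quantized one. Shapes and qparams here are illustrative.

import torch

x = torch.randn(1, 3, 4, 4)
qx = torch.quantize_per_tensor(x, scale=0.1, zero_point=128, dtype=torch.quint8)

# quantized::add_scalar returns a *quantized* tensor; its scale/zero_point
# are derived from the input's qparams (see the qadd.cpp link in the diff below).
qy = torch.ops.quantized.add_scalar(qx, 3.0)
print(qy.is_quantized, qy.q_scale(), qy.q_zero_point())

# Non-quantized consumers only see the value after an explicit dequantize:
print(torch.allclose(qy.dequantize(), qx.dequantize() + 3.0, atol=2 * qy.q_scale()))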
@@ -1103,21 +1103,6 @@ class TestQuantizeJitPasses(QuantizationTestCase):
                    .check("aten::dequantize") \
                    .run(model.graph)

-    def test_finalize_no_extra_dequantize(self):
-        class M(torch.nn.Module):
-            def __init__(self):
-                super(M, self).__init__()
-                self.conv = torch.nn.Conv2d(3, 3, 3).float()
-
-            def forward(self, x):
-                x = self.conv(x)
-                return x.size(0) * x
-
-        model = torch.jit.script(M()).eval()
-        model = quantize_jit(model, {'': default_qconfig}, test_only_eval_fn, [self.img_data])
-        FileCheck().check_not("aten::dequantize(") \
-                   .run(model.graph)
-
     def test_module_list(self):
         class SimpleLinearLayer(torch.nn.Module):
             def __init__(self):
@@ -1757,7 +1742,6 @@ class TestQuantizeJitOps(QuantizationTestCase):
         for tracing in [True, False]:
             # quantized::add_scalar_relu or quantized::add_scalar_relu_out
             # TODO: split this after refactor of checkGraphModeOp
-            # TODO: fix debug=True numerics
             m = self.checkGraphModeOp(m, data, "quantized::add_scalar_relu", tracing, check=False)
             FileCheck().check_not("aten::add(") \
                        .check_not("aten::add_(") \
@@ -1794,7 +1778,7 @@ class TestQuantizeJitOps(QuantizationTestCase):
                  torch.randn(1, 2, 5, 5, dtype=torch.float),
                  torch.randint(0, 1, (1,), dtype=torch.long)) for _ in range(2)]
         for tracing in [True, False]:
-            m = self.checkGraphModeOp(QuantizedCat(), data, "quantized::cat", tracing)
+            m = self.checkGraphModeOp(QuantizedCat(), data, "quantized::cat", tracing, debug=True)
             FileCheck().check_not("aten::cat") \
                        .run(m.graph)

@@ -2139,7 +2123,6 @@ class TestQuantizeJitOps(QuantizationTestCase):
                   InplaceMulScalarInplaceFunctionalRelu()]:
             for tracing in [True, False]:
                 # quantized::mul_scalar_relu or quantized::mul_scalar_relu_out
-                # TODO: fix debug=True numerics
                 m = self.checkGraphModeOp(m, data, "quantized::mul_scalar_relu", tracing, check=False)
                 FileCheck().check_not("aten::mul(") \
                            .check_not("aten::mul_(") \
@@ -2272,6 +2255,26 @@ class TestQuantizeJitOps(QuantizationTestCase):

             def forward(self, x):
                 x = self.conv(x)
+                # add_scalar
+                x = x + 3
+                # mul_scalar
+                x = x * 3
+                # add_scalar_out
+                x += 3
+                # mul_scalar_out
+                x *= 3
+                # add_scalar_relu
+                x = x + 3
+                x = F.relu(x)
+                # add_scalar_relu_out
+                x += 3
+                x = F.relu(x)
+                # mul_scalar_relu
+                x = x * 3
+                x = F.relu(x)
+                # mul_scalar_relu_out
+                x *= 3
+                x = F.relu(x)
                 x = self.maxpool1d(x)
                 x = self.maxpool2d(x)
                 x = self.maxpool3d(x)
@@ -2332,11 +2335,16 @@ class TestQuantizeJitOps(QuantizationTestCase):
         # observers and also successfully fused two quantized::conv2d
         # patterns
         # one quantize_per_tensor for input
+        # TODO: the checks are problematic, we need to split all checks
         FileCheck().check_count("aten::quantize_per_tensor", 1, exactly=True) \
                    .check_count("quantized::conv2d", 2, exactly=True) \
                    .check("aten::dequantize") \
                    .run(m.graph)
+
+        FileCheck().check("quantized::add_scalar") \
+                   .check("quantized::mul_scalar") \
+                   .run(m.graph)

     def test_general_value_ops(self):
         """ A test that checks correct patterns are produced for
         all supported general value ops like aten::avg_pool2d \
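
The FileCheck assertions above run over the textual IR of the scripted model. Here is a small self-contained sketch of the same mechanism, independent of the quantization test harness (module and checks are illustrative):

import torch
from torch.testing import FileCheck

class M(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x + 3)

m = torch.jit.script(M())
# Passes only if aten::add appears in the graph, followed by aten::relu.
FileCheck().check("aten::add").check("aten::relu").run(m.graph)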
@@ -353,8 +353,13 @@ bool isFunctionNode(
   return is_func_node;
 }

+bool isSingleInputGeneralShapeAtenFunction(Node* n) {
+  return isAtenFunc(n, _single_input_general_shape_aten_funcs);
+}
+
 bool isSingleInputGeneralValueAtenFunction(Node* n) {
-  return isAtenFunc(n, _single_input_general_value_aten_funcs);
+  return isAtenFunc(n, _single_input_general_value_aten_funcs) ||
+      isBinaryOpWithScalarInput(n);
 }

 bool isSingleInputGeneralCallFunction(Node* n) {
@@ -381,8 +386,8 @@ bool isSingleInputGeneralAtenFunction(Node* n) {
       std::back_inserter(fixed_qparams_aten_funcs),
       [](auto pair) { return pair.first; });

-  return isAtenFunc(n, _single_input_general_shape_aten_funcs) ||
-      isAtenFunc(n, _single_input_general_value_aten_funcs) ||
+  return isSingleInputGeneralValueAtenFunction(n) ||
+      isSingleInputGeneralShapeAtenFunction(n) ||
       isAtenFunc(n, fixed_qparams_aten_funcs);
 }

@@ -406,6 +411,10 @@ bool isPropagateQuantOp(Node* n) {
   return isPropagateQuantSingleInputOp(n) || isPropagateQuantBinaryOp(n);
 }

+bool isBinaryOpWithScalarInput(Node* n) {
+  return isPropagateQuantBinaryOp(n) && isScalar(n->input(1));
+}
+
 c10::optional<std::tuple<c10::QScheme, QParamVector>> getFixedQParams(Node* n) {
   static std::vector<NodeKind> fixed_qparam_funcs;
   std::transform(
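
isBinaryOpWithScalarInput(n) combines two checks that are visible from Python: the node is a propagatable binary op (aten::add/aten::mul) and its second input is a scalar rather than a tensor. A sketch showing the two IR shapes (function names are illustrative):

import torch

@torch.jit.script
def add_scalar(x: torch.Tensor):
    return x + 3  # scripts to aten::add(Tensor, Scalar other, Scalar alpha)

@torch.jit.script
def add_tensor(x: torch.Tensor, y: torch.Tensor):
    return x + y  # scripts to aten::add(Tensor, Tensor other, Scalar alpha)

print(add_scalar.graph)  # input 1 is an int constant, i.e. a scalar
print(add_tensor.graph)  # input 1 is a Tensor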
@@ -45,6 +45,8 @@ TORCH_API bool isScalar(Value* v);
 TORCH_API bool hitGraphInput(Value* value);

 // =========== helper functions for Node =========
+TORCH_API bool isSingleInputGeneralShapeAtenFunction(Node* n);
+
 TORCH_API bool isSingleInputGeneralValueAtenFunction(Node* n);

 TORCH_API bool isSingleInputGeneralCallFunction(Node* n);
@@ -67,6 +69,11 @@ TORCH_API bool isPropagateQuantBinaryOp(Node* n);
 // whether the input of the node is quantized, example: aten::cat
 TORCH_API bool isPropagateQuantOp(Node* n);

+// Check if the node is a binary op like aten::add or aten::mul whose
+// input 1 is a scalar; these ops will be quantized to
+// quantized::{op}_scalar
+TORCH_API bool isBinaryOpWithScalarInput(Node* n);
+
 TORCH_API c10::optional<std::tuple<c10::QScheme, QParamVector>> getFixedQParams(
     Node* n);

@@ -1019,16 +1019,6 @@ void InsertQuantDeQuantHelper::propagateQuantizationOps(Block* block) {
           propagateQParams(output, *inputs, /* is_scalar */ false, qparams_opt);
         }
       }
-    } else if (isPropagateQuantBinaryOp(n)) {
-      // Print warning for add_scalar/mul_scalar when debug is enabled
-      // since the quantization parameter for these ops depends on
-      // input and it's too complicated to encode the equations in
-      // the IR:
-      // https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/quantized/cpu/qadd.cpp#L64-L74
-      if (debug_ && isScalar(n->input(1))) {
-        TORCH_WARN_ONCE(
-            "debug option for add_scalar and mul_scalar is not supported");
-      }
     } else {
       // For ops that are quantized by propagating dequantize ops,
       // e.g. flatten we need to
@@ -1064,6 +1054,19 @@ void InsertQuantDeQuantHelper::propagateQuantizationOps(Block* block) {
           insertDeQuantForAllUse(output->owningGraph(), output, output);
         }
       }
+
+      if (isBinaryOpWithScalarInput(n)) {
+        // Print warning for add_scalar/mul_scalar when debug is enabled
+        // since the quantization parameter for these ops depends on
+        // input and it's too complicated to encode the equations in
+        // the IR:
+        // https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/quantized/cpu/qadd.cpp#L64-L74
+        if (debug_) {
+          TORCH_WARN_ONCE(
+              "debug option for add_scalar and mul_scalar is not supported, "
+              "please don't use debug option for models that use these ops.");
+        }
+      }
     }
   }

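
The warning exists because, as the comment says, quantized::{add,mul}_scalar pick their output qparams from the input's qparams at runtime, so the reference ("debug") pattern cannot encode them as constants in the IR. A quick sketch of that dependence (values are illustrative):

import torch

x = torch.randn(4)
for scale in (0.05, 0.1):
    qx = torch.quantize_per_tensor(x, scale=scale, zero_point=64, dtype=torch.quint8)
    qy = torch.ops.quantized.add_scalar(qx, 3.0)
    # Output qparams change with the input's qparams:
    print(scale, "->", qy.q_scale(), qy.q_zero_point())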
@@ -96,28 +96,54 @@ std::string getItem(const std::string& value) {
 }

-// Patterns for the ops that inherit parameters from input
-QuantFusionInfo getInputTensorQParamOpFusionInfo(
+std::string getInputTensorQParamOpPattern(
     const std::string& op_name,
     const std::vector<std::string>& extra_op_args) {
   const auto& extra_op_arg_list = getExtraArgList(extra_op_args);
-  std::string graph_header = "graph(%a_quant" + extra_op_arg_list + "):";
-  std::string op_pattern = graph_header;
-  op_pattern += R"(
+  // IR pattern common to all ops that inherit qparam from input
+  std::string op_pattern = "graph(%a_quant" + extra_op_arg_list + "):" + R"(
          %a_dequant = aten::dequantize(%a_quant)
-         %r = )";
-  op_pattern += op_name + "(" + "%a_dequant" + extra_op_arg_list + ")";
-  op_pattern += R"(
+         %r = )" +
+      op_name + "(" + "%a_dequant" + extra_op_arg_list + ")" + R"(
          %r_scale : float = aten::q_scale(%a_quant)
          %r_zero_point : int = aten::q_zero_point(%a_quant)
          %r_dtype : int = prim::dtype(%a_quant)
          %r_quant = aten::quantize_per_tensor(%r, %r_scale, %r_zero_point, %r_dtype)
          return (%r_quant) )";
+  return op_pattern;
+}

-  std::string aten_op_pattern =
+// QuantFusionInfo for the ops that inherit parameters from input
+QuantFusionInfo getInputTensorQParamOpFusionInfo(
+    const std::string& op_name,
+    const std::vector<std::string>& extra_op_args) {
+  std::string op_pattern =
+      getInputTensorQParamOpPattern(op_name, extra_op_args);
+  const auto& extra_op_arg_list = getExtraArgList(extra_op_args);
+  std::string graph_header = "graph(%a_quant" + extra_op_arg_list + "):";
+  std::string op_replacement =
       getAtenOpPattern(graph_header, op_name, extra_op_args);

-  return {op_name, op_pattern, aten_op_pattern};
+  return {op_name, op_pattern, op_replacement};
 }

+// quant fusion for ops like `quantized::add_scalar`, `quantized::mul_scalar`
+QuantFusionInfo getBinaryOpScalarFusionInfo(
+    const std::string& op_name,
+    const std::vector<std::string>& extra_op_args,
+    const std::string& quantized_op_name,
+    const std::vector<std::string>& extra_quantized_op_args,
+    const std::vector<MatchFilter>& filters = {}) {
+  std::string op_pattern =
+      getInputTensorQParamOpPattern(op_name, extra_op_args);
+
+  const auto& extra_op_arg_list = getExtraArgList(extra_op_args);
+  std::string graph_header = "graph(%a_quant" + extra_op_arg_list + "):";
+  const auto& extra_quantized_op_arg_list =
+      getExtraArgList(extra_quantized_op_args);
+  std::string op_replacement = getAtenOpPattern(
+      graph_header, quantized_op_name, extra_quantized_op_args);
+
+  return {op_name, op_pattern, op_replacement, filters};
+}
+
 QuantFusionInfo getClampOpFusionInfo(
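
For reference, a Python transliteration (hypothetical helper name, mirroring getInputTensorQParamOpPattern above) of the IR pattern being assembled: dequantize the input, run the float op, then re-quantize using qparams read off the input tensor:

def input_tensor_qparam_op_pattern(op_name, extra_args):
    arg_list = "".join(", " + a for a in extra_args)
    return (
        "graph(%a_quant" + arg_list + "):\n"
        "         %a_dequant = aten::dequantize(%a_quant)\n"
        "         %r = " + op_name + "(%a_dequant" + arg_list + ")\n"
        "         %r_scale : float = aten::q_scale(%a_quant)\n"
        "         %r_zero_point : int = aten::q_zero_point(%a_quant)\n"
        "         %r_dtype : int = prim::dtype(%a_quant)\n"
        "         %r_quant = aten::quantize_per_tensor(%r, %r_scale, %r_zero_point, %r_dtype)\n"
        "         return (%r_quant)"
    )

print(input_tensor_qparam_op_pattern("aten::add", ["%b_scalar", "%alpha"]))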
@@ -504,67 +530,55 @@ graph(%a_quant, %b_quant, %alpha, %scale, %zero_point, %dtype):
          %r = aten::quantize_per_tensor(%r_add, %scale, %zero_point, %dtype)
          return (%r) )";

-  // quantized::add_scalar
-  std::string add_scalar = R"(
-graph(%a_quant, %b_scalar, %alpha):
-         %a_dequant = aten::dequantize(%a_quant)
-         %r = aten::add(%a_dequant, %b_scalar, %alpha)
-         return (%r) )";
+  auto add_scalar = getBinaryOpScalarFusionInfo(
+      "aten::add",
+      {"%b_scalar", "%alpha"},
+      "quantized::add_scalar",
+      {"%b_scalar"},
+      {aten_add_alpha_is_one, input_b_is_scalar});

-  std::string quantized_add_scalar = R"(
-graph(%a_quant, %b_scalar, %alpha):
-         %r = quantized::add_scalar(%a_quant, %b_scalar)
-         return (%r) )";
+  auto add_scalar_out = getBinaryOpScalarFusionInfo(
+      "aten::add_",
+      {"%b_scalar", "%alpha"},
+      "quantized::add_scalar_out",
+      {"%b_scalar", "%a_quant"},
+      {aten_add_alpha_is_one, input_b_is_scalar});

-  // quantized::add_scalar_out
-  std::string inplace_add_scalar = R"(
-graph(%a_quant, %b_scalar, %alpha):
-         %a_dequant = aten::dequantize(%a_quant)
-         %r = aten::add_(%a_dequant, %b_scalar, %alpha)
-         return (%r) )";
-
-  std::string quantized_add_scalar_out = R"(
-graph(%a_quant, %b_scalar, %alpha):
-         %r = quantized::add_scalar_out(%a_quant, %b_scalar, %a_quant)
-         return (%r) )";
-
-  // quantized::add_scalar_relu
-  std::string add_scalar_relu = R"(
-graph(%a_quant, %b_scalar, %alpha):
-         %a_dequant = aten::dequantize(%a_quant)
-         %r_add = aten::add(%a_dequant, %b_scalar, %alpha)
+  // quantized::add_scalar_relu -- fusing quantized::add_scalar
+  // and aten::relu
+  auto quantized_add_scalar_relu_pattern = R"(
+graph(%a_quant, %b_scalar):
+         %r_add = quantized::add_scalar(%a_quant, %b_scalar)
          %r = aten::relu(%r_add)
          return (%r) )";

-  std::string add_scalar_inplace_relu = R"(
-graph(%a_quant, %b_scalar, %alpha):
-         %a_dequant = aten::dequantize(%a_quant)
-         %r_add = aten::add(%a_dequant, %b_scalar, %alpha)
+  auto quantized_add_scalar_inplace_relu_pattern = R"(
+graph(%a_quant, %b_scalar):
+         %r_add = quantized::add_scalar(%a_quant, %b_scalar)
          %r = aten::relu_(%r_add)
          return (%r) )";

-  std::string quantized_add_scalar_relu = R"(
-graph(%a_quant, %b_scalar, %alpha):
+  auto quantized_add_scalar_relu_replacement = R"(
+graph(%a_quant, %b_scalar):
          %r = quantized::add_scalar_relu(%a_quant, %b_scalar)
          return (%r) )";

-  // quantized::add_scalar_relu_out
-  std::string inplace_add_scalar_relu = R"(
-graph(%a_quant, %b_scalar, %alpha):
-         %a_dequant = aten::dequantize(%a_quant)
-         %r_add = aten::add_(%a_dequant, %b_scalar, %alpha)
+  // quantized::add_scalar_relu_out -- fusing quantized::add_scalar_out
+  // and aten::relu
+  auto quantized_add_scalar_relu_out_pattern = R"(
+graph(%a_quant, %b_scalar):
+         %r_add = quantized::add_scalar_out(%a_quant, %b_scalar, %a_quant)
          %r = aten::relu(%r_add)
          return (%r) )";

-  std::string inplace_add_scalar_inplace_relu = R"(
-graph(%a_quant, %b_scalar, %alpha):
-         %a_dequant = aten::dequantize(%a_quant)
-         %r_add = aten::add_(%a_dequant, %b_scalar, %alpha)
+  auto quantized_add_scalar_inplace_relu_out_pattern = R"(
+graph(%a_quant, %b_scalar):
+         %r_add = quantized::add_scalar_out(%a_quant, %b_scalar, %a_quant)
          %r = aten::relu_(%r_add)
          return (%r) )";

-  std::string quantized_add_scalar_relu_out = R"(
-graph(%a_quant, %b_scalar, %alpha):
+  auto quantized_add_scalar_relu_out_replacement = R"(
+graph(%a_quant, %b_scalar):
          %r = quantized::add_scalar_relu_out(%a_quant, %b_scalar, %a_quant)
          return (%r) )";
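
Each QuantFusionInfo above is a (pattern, replacement) pair fed to the JIT subgraph rewriter. A hedged sketch of the same machinery through its internal Python binding, with an illustrative op name (my::add_relu) standing in for the quantized ops:

import torch

@torch.jit.script
def f(x: torch.Tensor):
    return torch.relu(x + 3)

pattern = """
graph(%x, %b, %alpha):
    %a = aten::add(%x, %b, %alpha)
    %r = aten::relu(%a)
    return (%r)"""
replacement = """
graph(%x, %b, %alpha):
    %r = my::add_relu(%x, %b, %alpha)
    return (%r)"""

# Internal API; used the same way in PyTorch's own JIT tests.
torch._C._jit_pass_custom_pattern_based_rewrite_graph(pattern, replacement, f.graph)
print(f.graph)  # aten::add + aten::relu replaced by my::add_relu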
@@ -624,28 +638,19 @@ graph(%a_quant, %b_quant, %scale, %zero_point, %dtype):
          %r = quantized::mul(%a_quant, %b_quant, %scale, %zero_point)
          return (%r) )";

-  // quantized::mul_scalar
-  std::string mul_scalar = R"(
-graph(%a_quant, %b_scalar):
-         %a_dequant = aten::dequantize(%a_quant)
-         %r = aten::mul(%a_dequant, %b_scalar)
-         return (%r) )";
+  auto mul_scalar = getBinaryOpScalarFusionInfo(
+      "aten::mul",
+      {"%b_scalar"},
+      "quantized::mul_scalar",
+      {"%b_scalar"},
+      {input_b_is_scalar});

-  std::string inplace_mul_scalar = R"(
-graph(%a_quant, %b_scalar):
-         %a_dequant = aten::dequantize(%a_quant)
-         %r = aten::mul_(%a_dequant, %b_scalar)
-         return (%r) )";
-
-  std::string quantized_mul_scalar = R"(
-graph(%a_quant, %b_scalar):
-         %r = quantized::mul_scalar(%a_quant, %b_scalar)
-         return (%r) )";
-
-  std::string quantized_mul_scalar_out = R"(
-graph(%a_quant, %b_scalar):
-         %r = quantized::mul_scalar_out(%a_quant, %b_scalar, %a_quant)
-         return (%r) )";
+  auto mul_scalar_out = getBinaryOpScalarFusionInfo(
+      "aten::mul_",
+      {"%b_scalar"},
+      "quantized::mul_scalar_out",
+      {"%b_scalar", "%a_quant"},
+      {input_b_is_scalar});

   // quantized::mul_relu
   std::string mul_relu = R"(
@@ -689,42 +694,40 @@ graph(%a_quant, %b_quant, %scale, %zero_point, %dtype):
          %r = quantized::mul_relu(%a_quant, %b_quant, %scale, %zero_point)
          return (%r) )";

-  // quantized::mul_scalar_relu
-  std::string mul_scalar_relu = R"(
+  // quantized::mul_scalar_relu -- fusing quantized::mul_scalar
+  // and aten::relu
+  auto quantized_mul_scalar_relu_pattern = R"(
 graph(%a_quant, %b_scalar):
-         %a_dequant = aten::dequantize(%a_quant)
-         %r_mul = aten::mul(%a_dequant, %b_scalar)
+         %r_mul = quantized::mul_scalar(%a_quant, %b_scalar)
          %r = aten::relu(%r_mul)
          return (%r) )";

-  std::string mul_scalar_inplace_relu = R"(
+  auto quantized_mul_scalar_inplace_relu_pattern = R"(
 graph(%a_quant, %b_scalar):
-         %a_dequant = aten::dequantize(%a_quant)
-         %r_mul = aten::mul(%a_dequant, %b_scalar)
+         %r_mul = quantized::mul_scalar(%a_quant, %b_scalar)
          %r = aten::relu_(%r_mul)
          return (%r) )";

-  std::string quantized_mul_scalar_relu = R"(
+  auto quantized_mul_scalar_relu_replacement = R"(
 graph(%a_quant, %b_scalar):
          %r = quantized::mul_scalar_relu(%a_quant, %b_scalar)
          return (%r) )";

-  // quantized::mul_scalar_relu_out
-  std::string inplace_mul_scalar_relu = R"(
+  // quantized::mul_scalar_relu_out -- fusing quantized::mul_scalar_out
+  // and aten::relu
+  auto quantized_mul_scalar_relu_out_pattern = R"(
 graph(%a_quant, %b_scalar):
-         %a_dequant = aten::dequantize(%a_quant)
-         %r_mul = aten::mul_(%a_dequant, %b_scalar)
+         %r_mul = quantized::mul_scalar_out(%a_quant, %b_scalar, %a_quant)
          %r = aten::relu(%r_mul)
          return (%r) )";

-  std::string inplace_mul_scalar_inplace_relu = R"(
+  auto quantized_mul_scalar_inplace_relu_out_pattern = R"(
 graph(%a_quant, %b_scalar):
-         %a_dequant = aten::dequantize(%a_quant)
-         %r_mul = aten::mul_(%a_dequant, %b_scalar)
+         %r_mul = quantized::mul_scalar_out(%a_quant, %b_scalar, %a_quant)
          %r = aten::relu_(%r_mul)
          return (%r) )";

-  std::string quantized_mul_scalar_relu_out = R"(
+  auto quantized_mul_scalar_relu_out_replacement = R"(
 graph(%a_quant, %b_scalar):
          %r = quantized::mul_scalar_relu_out(%a_quant, %b_scalar, %a_quant)
          return (%r) )";
@@ -905,31 +908,22 @@ graph(%a_quant, %alpha, %scale, %input_scale, %r_scale, %r_zero_point, %r_dtype)
        inplace_add_inplace_relu,
        quantized_add_relu,
        {aten_add_alpha_is_one}},
-      // note that this must come before quantized::add_scalar
+      add_scalar,
+      add_scalar_out,
+      // note that these must come after quantized::add_scalar and
+      // quantized::add_scalar_out patterns
       {"quantized::add_scalar_relu",
-       add_scalar_relu,
-       quantized_add_scalar_relu,
-       {aten_add_alpha_is_one, input_b_is_scalar}},
+       quantized_add_scalar_relu_pattern,
+       quantized_add_scalar_relu_replacement},
       {"quantized::add_scalar_relu",
-       add_scalar_inplace_relu,
-       quantized_add_scalar_relu,
-       {aten_add_alpha_is_one, input_b_is_scalar}},
+       quantized_add_scalar_inplace_relu_pattern,
+       quantized_add_scalar_relu_replacement},
       {"quantized::add_scalar_relu_out",
-       inplace_add_scalar_relu,
-       quantized_add_scalar_relu_out,
-       {aten_add_alpha_is_one, input_b_is_scalar}},
+       quantized_add_scalar_relu_out_pattern,
+       quantized_add_scalar_relu_out_replacement},
       {"quantized::add_scalar_relu_out",
-       inplace_add_scalar_inplace_relu,
-       quantized_add_scalar_relu_out,
-       {aten_add_alpha_is_one, input_b_is_scalar}},
-      {"quantized::add_scalar",
-       add_scalar,
-       quantized_add_scalar,
-       {aten_add_alpha_is_one, input_b_is_scalar}},
-      {"quantized::add_scalar_out",
-       inplace_add_scalar,
-       quantized_add_scalar_out,
-       {aten_add_alpha_is_one, input_b_is_scalar}},
+       quantized_add_scalar_inplace_relu_out_pattern,
+       quantized_add_scalar_relu_out_replacement},
       {"quantized::add", add, quantized_add, {aten_add_alpha_is_one}},
       {"quantized::add", inplace_add, quantized_add, {aten_add_alpha_is_one}},
       {"quantized::cat", cat, quantized_cat},
@@ -940,30 +934,22 @@ graph(%a_quant, %alpha, %scale, %input_scale, %r_scale, %r_zero_point, %r_dtype)
       {"quantized::batch_norm_relu",
        batch_norm_inplace_relu,
        quantized_batch_norm_relu},
+      mul_scalar,
+      mul_scalar_out,
+      // note that these must come after quantized::mul_scalar and
+      // quantized::mul_scalar_out patterns
       {"quantized::mul_scalar_relu",
-       mul_scalar_relu,
-       quantized_mul_scalar_relu,
-       {input_b_is_scalar}},
+       quantized_mul_scalar_relu_pattern,
+       quantized_mul_scalar_relu_replacement},
       {"quantized::mul_scalar_relu",
-       mul_scalar_inplace_relu,
-       quantized_mul_scalar_relu,
-       {input_b_is_scalar}},
+       quantized_mul_scalar_inplace_relu_pattern,
+       quantized_mul_scalar_relu_replacement},
       {"quantized::mul_scalar_relu_out",
-       inplace_mul_scalar_relu,
-       quantized_mul_scalar_relu_out,
-       {input_b_is_scalar}},
+       quantized_mul_scalar_relu_out_pattern,
+       quantized_mul_scalar_relu_out_replacement},
       {"quantized::mul_scalar_relu_out",
-       inplace_mul_scalar_inplace_relu,
-       quantized_mul_scalar_relu_out,
-       {input_b_is_scalar}},
-      {"quantized::mul_scalar",
-       mul_scalar,
-       quantized_mul_scalar,
-       {input_b_is_scalar}},
-      {"quantized::mul_scalar",
-       inplace_mul_scalar,
-       quantized_mul_scalar_out,
-       {input_b_is_scalar}},
+       quantized_mul_scalar_inplace_relu_out_pattern,
+       quantized_mul_scalar_relu_out_replacement},
       {"quantized::mul_relu", mul_relu, quantized_mul_relu},
       {"quantized::mul_relu", mul_inplace_relu, quantized_mul_relu},
       {"quantized::mul_relu", inplace_mul_relu, quantized_mul_relu},
@@ -308,7 +308,7 @@ RegisterOperators reg(
           return 0;
         },
         aliasAnalysisFromSchema()),
-    // only used internally in range() translation
+        // only used internally in range() translation
     Operator(
         "aten::__range_length(int lo, int hi, int step) -> int",
         [](Stack& stack) {