[quant][graphmode][fix] dequantize propagation for {add/mul}_scalar (#40596)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/40596

Previously the fusion patterns for {add/mul}_scalar were inconsistent: the op pattern
produced a non-quantized tensor while the op replacement graph produced a quantized
tensor. The op pattern now quantizes its output with the input tensor's qparams, so
both sides of the fusion produce a quantized tensor and dequantize propagation stays
consistent.
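
For illustration (an editor's sketch, not part of this PR's test plan), a toy module
exercising the add_scalar path; this assumes the torch.quantization.quantize_jit API
as of this commit:

import torch
from torch.quantization import default_qconfig, quantize_jit

class M(torch.nn.Module):
    def __init__(self):
        super(M, self).__init__()
        self.conv = torch.nn.Conv2d(3, 3, 3).float()

    def forward(self, x):
        # scalar add: should lower to quantized::add_scalar
        return self.conv(x) + 3

def calibrate(model, data):
    # trivial calibration loop to populate the observers
    for x in data:
        model(x)

data = [torch.rand(1, 3, 10, 10, dtype=torch.float)]
m = quantize_jit(torch.jit.script(M()).eval(), {'': default_qconfig},
                 calibrate, [data])
# with this fix, the pattern and its replacement both yield a quantized
# tensor, so dequantize nodes for downstream ops are placed consistently
print(m.graph)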

Test Plan: Imported from OSS

Differential Revision: D22251072

fbshipit-source-id: e16eb92cf6611578cca1ed8ebde961f8d0610137
Jerry Zhang 2020-06-25 22:15:36 -07:00 committed by Facebook GitHub Bot
parent 547ea787ff
commit e3a97688cc
6 changed files with 177 additions and 164 deletions

View File

@@ -1103,21 +1103,6 @@ class TestQuantizeJitPasses(QuantizationTestCase):
.check("aten::dequantize") \
.run(model.graph)
def test_finalize_no_extra_dequantize(self):
class M(torch.nn.Module):
def __init__(self):
super(M, self).__init__()
self.conv = torch.nn.Conv2d(3, 3, 3).float()
def forward(self, x):
x = self.conv(x)
return x.size(0) * x
model = torch.jit.script(M()).eval()
model = quantize_jit(model, {'': default_qconfig}, test_only_eval_fn, [self.img_data])
FileCheck().check_not("aten::dequantize(") \
.run(model.graph)
def test_module_list(self):
class SimpleLinearLayer(torch.nn.Module):
def __init__(self):
@@ -1757,7 +1742,6 @@ class TestQuantizeJitOps(QuantizationTestCase):
for tracing in [True, False]:
# quantized::add_scalar_relu or quantized::add_scalar_relu_out
# TODO: split this after refactor of checkGraphModeOp
# TODO: fix debug=True numerics
m = self.checkGraphModeOp(m, data, "quantized::add_scalar_relu", tracing, check=False)
FileCheck().check_not("aten::add(") \
.check_not("aten::add_(") \
@@ -1794,7 +1778,7 @@ class TestQuantizeJitOps(QuantizationTestCase):
torch.randn(1, 2, 5, 5, dtype=torch.float),
torch.randint(0, 1, (1,), dtype=torch.long)) for _ in range(2)]
for tracing in [True, False]:
m = self.checkGraphModeOp(QuantizedCat(), data, "quantized::cat", tracing)
m = self.checkGraphModeOp(QuantizedCat(), data, "quantized::cat", tracing, debug=True)
FileCheck().check_not("aten::cat") \
.run(m.graph)
@@ -2139,7 +2123,6 @@ class TestQuantizeJitOps(QuantizationTestCase):
InplaceMulScalarInplaceFunctionalRelu()]:
for tracing in [True, False]:
# quantized::mul_scalar_relu or quantized::mul_scalar_relu_out
# TODO: fix debug=True numerics
m = self.checkGraphModeOp(m, data, "quantized::mul_scalar_relu", tracing, check=False)
FileCheck().check_not("aten::mul(") \
.check_not("aten::mul_(") \
@@ -2272,6 +2255,26 @@ class TestQuantizeJitOps(QuantizationTestCase):
def forward(self, x):
x = self.conv(x)
# add_scalar
x = x + 3
# mul_scalar
x = x * 3
# add_scalar_out
x += 3
# mul_scalar_out
x *= 3
# add_scalar_relu
x = x + 3
x = F.relu(x)
# add_scalar_relu_out
x += 3
x = F.relu(x)
# mul_scalar_relu
x = x * 3
x = F.relu(x)
# mul_scalar_relu_out
x *= 3
x = F.relu(x)
x = self.maxpool1d(x)
x = self.maxpool2d(x)
x = self.maxpool3d(x)
@@ -2332,11 +2335,16 @@ class TestQuantizeJitOps(QuantizationTestCase):
# observers and also successfully fused two quantized::conv2d
# patterns
# one quantize_per_tensor for input
# TODO: the checks are problematic, we need to split all checks
FileCheck().check_count("aten::quantize_per_tensor", 1, exactly=True) \
.check_count("quantized::conv2d", 2, exactly=True) \
.check("aten::dequantize") \
.run(m.graph)
FileCheck().check("quantized::add_scalar") \
.check("quantized::mul_scalar") \
.run(m.graph)
def test_general_value_ops(self):
""" A test that checks correct patterns are produced for
all supported general value ops like aten::avg_pool2d \

View File

@@ -353,8 +353,13 @@ bool isFunctionNode(
return is_func_node;
}
bool isSingleInputGeneralShapeAtenFunction(Node* n) {
return isAtenFunc(n, _single_input_general_shape_aten_funcs);
}
bool isSingleInputGeneralValueAtenFunction(Node* n) {
return isAtenFunc(n, _single_input_general_value_aten_funcs);
return isAtenFunc(n, _single_input_general_value_aten_funcs) ||
isBinaryOpWithScalarInput(n);
}
bool isSingleInputGeneralCallFunction(Node* n) {
@@ -381,8 +386,8 @@ bool isSingleInputGeneralAtenFunction(Node* n) {
std::back_inserter(fixed_qparams_aten_funcs),
[](auto pair) { return pair.first; });
return isAtenFunc(n, _single_input_general_shape_aten_funcs) ||
isAtenFunc(n, _single_input_general_value_aten_funcs) ||
return isSingleInputGeneralValueAtenFunction(n) ||
isSingleInputGeneralShapeAtenFunction(n) ||
isAtenFunc(n, fixed_qparams_aten_funcs);
}
@@ -406,6 +411,10 @@ bool isPropagateQuantOp(Node* n) {
return isPropagateQuantSingleInputOp(n) || isPropagateQuantBinaryOp(n);
}
bool isBinaryOpWithScalarInput(Node* n) {
return isPropagateQuantBinaryOp(n) && isScalar(n->input(1));
}
c10::optional<std::tuple<c10::QScheme, QParamVector>> getFixedQParams(Node* n) {
static std::vector<NodeKind> fixed_qparam_funcs;
std::transform(

View File

@@ -45,6 +45,8 @@ TORCH_API bool isScalar(Value* v);
TORCH_API bool hitGraphInput(Value* value);
// =========== helper functions for Node =========
TORCH_API bool isSingleInputGeneralShapeAtenFunction(Node* n);
TORCH_API bool isSingleInputGeneralValueAtenFunction(Node* n);
TORCH_API bool isSingleInputGeneralCallFunction(Node* n);
@@ -67,6 +69,11 @@ TORCH_API bool isPropagateQuantBinaryOp(Node* n);
// whether the input of the node is quantized, example: aten::cat
TORCH_API bool isPropagateQuantOp(Node* n);
// Check if the node is a binary op like aten::add or aten::mul whose
// input 1 is a scalar; these ops will be quantized to
// quantized::{op}_scalar
TORCH_API bool isBinaryOpWithScalarInput(Node* n);
TORCH_API c10::optional<std::tuple<c10::QScheme, QParamVector>> getFixedQParams(
Node* n);
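
To make the predicate above concrete, a small illustrative script (editor's example,
not from this PR) showing the kind of node isBinaryOpWithScalarInput targets, an
aten::add whose input 1 is a scalar constant:

import torch

@torch.jit.script
def add_scalar(x: torch.Tensor) -> torch.Tensor:
    # input 1 of aten::add is the int constant 3, so this node would be
    # quantized to quantized::add_scalar
    return x + 3

@torch.jit.script
def add_tensor(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # both inputs are tensors, so this takes the quantized::add path
    return x + y

print(add_scalar.graph)  # aten::add(%x, %3, %alpha) with a constant %3
print(add_tensor.graph)  # aten::add(%x, %y, %alpha)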

View File

@@ -1019,16 +1019,6 @@ void InsertQuantDeQuantHelper::propagateQuantizationOps(Block* block) {
propagateQParams(output, *inputs, /* is_scalar */ false, qparams_opt);
}
}
} else if (isPropagateQuantBinaryOp(n)) {
// Print warning for add_scalar/mul_scalar when debug is enabled
// since the quantization parameter for these ops depends on
// input and it's too complicated to encode the equations in
// the IR:
// https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/quantized/cpu/qadd.cpp#L64-L74
if (debug_ && isScalar(n->input(1))) {
TORCH_WARN_ONCE(
"debug option for add_scalar and mul_scalar is not supported");
}
} else {
// For ops that are quantized by propagating dequantize ops,
// e.g. flatten we need to
@@ -1064,6 +1054,19 @@ void InsertQuantDeQuantHelper::propagateQuantizationOps(Block* block) {
insertDeQuantForAllUse(output->owningGraph(), output, output);
}
}
if (isBinaryOpWithScalarInput(n)) {
// Print warning for add_scalar/mul_scalar when debug is enabled
// since the quantization parameter for these ops depends on
// input and it's too complicated to encode the equations in
// the IR:
// https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/quantized/cpu/qadd.cpp#L64-L74
if (debug_) {
TORCH_WARN_ONCE(
"debug option for add_scalar and mul_scalar is not supported, "
"please don't use debug option for models that uses these ops.");
}
}
}
}
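
The equations the warning refers to look roughly like the following (a Python
paraphrase of the qadd.cpp logic linked above; the names and rounding here are
illustrative, not the exact ATen implementation):

def add_scalar_out_qparams(scale, zero_point, scalar, q_min=0, q_max=255):
    # quantize the scalar in the input's scale
    c_q = round(scalar / scale)
    shifted = zero_point - c_q
    if shifted < q_min:
        # shifted zero point underflows the dtype range:
        # stretch the scale and pin the zero point at q_min
        return (q_max - shifted) / (q_max - q_min) * scale, q_min
    elif shifted > q_max:
        # shifted zero point overflows: pin it at q_max instead
        return (shifted - q_min) / (q_max - q_min) * scale, q_max
    else:
        # representable shift: keep the scale, move the zero point
        return scale, shifted

Because the output scale and zero point are a runtime function of the input's
qparams like this, they cannot be written down as constants in a rewrite pattern,
which is why debug mode only warns.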

View File

@@ -96,28 +96,54 @@ std::string getItem(const std::string& value) {
}
// Patterns for the ops that inherit parameters from input
QuantFusionInfo getInputTensorQParamOpFusionInfo(
std::string getInputTensorQParamOpPattern(
const std::string& op_name,
const std::vector<std::string>& extra_op_args) {
const auto& extra_op_arg_list = getExtraArgList(extra_op_args);
std::string graph_header = "graph(%a_quant" + extra_op_arg_list + "):";
std::string op_pattern = graph_header;
op_pattern += R"(
std::string op_pattern = "graph(%a_quant" + extra_op_arg_list + "):" + R"(
%a_dequant = aten::dequantize(%a_quant)
%r = )";
op_pattern += op_name + "(" + "%a_dequant" + extra_op_arg_list + ")";
// IR pattern common to all ops that inherit qparam from input
op_pattern += R"(
%r = )" +
op_name + "(" + "%a_dequant" + extra_op_arg_list + ")" + R"(
%r_scale : float = aten::q_scale(%a_quant)
%r_zero_point : int = aten::q_zero_point(%a_quant)
%r_dtype : int = prim::dtype(%a_quant)
%r_quant = aten::quantize_per_tensor(%r, %r_scale, %r_zero_point, %r_dtype)
return (%r_quant) )";
return op_pattern;
}
std::string aten_op_pattern =
// QuantFusionInfo for the ops that inherit parameters from input
QuantFusionInfo getInputTensorQParamOpFusionInfo(
const std::string& op_name,
const std::vector<std::string>& extra_op_args) {
std::string op_pattern =
getInputTensorQParamOpPattern(op_name, extra_op_args);
const auto& extra_op_arg_list = getExtraArgList(extra_op_args);
std::string graph_header = "graph(%a_quant" + extra_op_arg_list + "):";
std::string op_replacement =
getAtenOpPattern(graph_header, op_name, extra_op_args);
return {op_name, op_pattern, aten_op_pattern};
return {op_name, op_pattern, op_replacement};
}
// quant fusion for ops like `quantized::add_scalar`, `quantized::mul_scalar`
QuantFusionInfo getBinaryOpScalarFusionInfo(
const std::string& op_name,
const std::vector<std::string>& extra_op_args,
const std::string& quantized_op_name,
const std::vector<std::string>& extra_quantized_op_args,
const std::vector<MatchFilter>& filters = {}) {
std::string op_pattern =
getInputTensorQParamOpPattern(op_name, extra_op_args);
const auto& extra_op_arg_list = getExtraArgList(extra_op_args);
std::string graph_header = "graph(%a_quant" + extra_op_arg_list + "):";
const auto& extra_quantized_op_arg_list =
getExtraArgList(extra_quantized_op_args);
std::string op_replacement = getAtenOpPattern(
graph_header, quantized_op_name, extra_quantized_op_args);
return {op_name, op_pattern, op_replacement, filters};
}
QuantFusionInfo getClampOpFusionInfo(
@@ -504,67 +530,55 @@ graph(%a_quant, %b_quant, %alpha, %scale, %zero_point, %dtype):
%r = aten::quantize_per_tensor(%r_add, %scale, %zero_point, %dtype)
return (%r) )";
// quantized::add_scalar
std::string add_scalar = R"(
graph(%a_quant, %b_scalar, %alpha):
%a_dequant = aten::dequantize(%a_quant)
%r = aten::add(%a_dequant, %b_scalar, %alpha)
return (%r) )";
auto add_scalar = getBinaryOpScalarFusionInfo(
"aten::add",
{"%b_scalar", "%alpha"},
"quantized::add_scalar",
{"%b_scalar"},
{aten_add_alpha_is_one, input_b_is_scalar});
std::string quantized_add_scalar = R"(
graph(%a_quant, %b_scalar, %alpha):
%r = quantized::add_scalar(%a_quant, %b_scalar)
return (%r) )";
auto add_scalar_out = getBinaryOpScalarFusionInfo(
"aten::add_",
{"%b_scalar", "%alpha"},
"quantized::add_scalar_out",
{"%b_scalar", "%a_quant"},
{aten_add_alpha_is_one, input_b_is_scalar});
// quantized::add_scalar_out
std::string inplace_add_scalar = R"(
graph(%a_quant, %b_scalar, %alpha):
%a_dequant = aten::dequantize(%a_quant)
%r = aten::add_(%a_dequant, %b_scalar, %alpha)
return (%r) )";
std::string quantized_add_scalar_out = R"(
graph(%a_quant, %b_scalar, %alpha):
%r = quantized::add_scalar_out(%a_quant, %b_scalar, %a_quant)
return (%r) )";
// quantized::add_scalar_relu
std::string add_scalar_relu = R"(
graph(%a_quant, %b_scalar, %alpha):
%a_dequant = aten::dequantize(%a_quant)
%r_add = aten::add(%a_dequant, %b_scalar, %alpha)
// quantized::add_scalar_relu -- fusing quantized::add_scalar
// and aten::relu
auto quantized_add_scalar_relu_pattern = R"(
graph(%a_quant, %b_scalar):
%r_add = quantized::add_scalar(%a_quant, %b_scalar)
%r = aten::relu(%r_add)
return (%r) )";
std::string add_scalar_inplace_relu = R"(
graph(%a_quant, %b_scalar, %alpha):
%a_dequant = aten::dequantize(%a_quant)
%r_add = aten::add(%a_dequant, %b_scalar, %alpha)
auto quantized_add_scalar_inplace_relu_pattern = R"(
graph(%a_quant, %b_scalar):
%r_add = quantized::add_scalar(%a_quant, %b_scalar)
%r = aten::relu_(%r_add)
return (%r) )";
std::string quantized_add_scalar_relu = R"(
graph(%a_quant, %b_scalar, %alpha):
auto quantized_add_scalar_relu_replacement = R"(
graph(%a_quant, %b_scalar):
%r = quantized::add_scalar_relu(%a_quant, %b_scalar)
return (%r) )";
// quantized::add_scalar_relu_out
std::string inplace_add_scalar_relu = R"(
graph(%a_quant, %b_scalar, %alpha):
%a_dequant = aten::dequantize(%a_quant)
%r_add = aten::add_(%a_dequant, %b_scalar, %alpha)
// quantized::add_scalar_relu_out -- fusing quantized::add_scalar_out
// and aten::relu
auto quantized_add_scalar_relu_out_pattern = R"(
graph(%a_quant, %b_scalar):
%r_add = quantized::add_scalar_out(%a_quant, %b_scalar, %a_quant)
%r = aten::relu(%r_add)
return (%r) )";
std::string inplace_add_scalar_inplace_relu = R"(
graph(%a_quant, %b_scalar, %alpha):
%a_dequant = aten::dequantize(%a_quant)
%r_add = aten::add_(%a_dequant, %b_scalar, %alpha)
auto quantized_add_scalar_inplace_relu_out_pattern = R"(
graph(%a_quant, %b_scalar):
%r_add = quantized::add_scalar_out(%a_quant, %b_scalar, %a_quant)
%r = aten::relu_(%r_add)
return (%r) )";
std::string quantized_add_scalar_relu_out = R"(
graph(%a_quant, %b_scalar, %alpha):
auto quantized_add_scalar_relu_out_replacement = R"(
graph(%a_quant, %b_scalar):
%r = quantized::add_scalar_relu_out(%a_quant, %b_scalar, %a_quant)
return (%r) )";
@@ -624,28 +638,19 @@ graph(%a_quant, %b_quant, %scale, %zero_point, %dtype):
%r = quantized::mul(%a_quant, %b_quant, %scale, %zero_point)
return (%r) )";
// quantized::mul_scalar
std::string mul_scalar = R"(
graph(%a_quant, %b_scalar):
%a_dequant = aten::dequantize(%a_quant)
%r = aten::mul(%a_dequant, %b_scalar)
return (%r) )";
auto mul_scalar = getBinaryOpScalarFusionInfo(
"aten::mul",
{"%b_scalar"},
"quantized::mul_scalar",
{"%b_scalar"},
{input_b_is_scalar});
std::string inplace_mul_scalar = R"(
graph(%a_quant, %b_scalar):
%a_dequant = aten::dequantize(%a_quant)
%r = aten::mul_(%a_dequant, %b_scalar)
return (%r) )";
std::string quantized_mul_scalar = R"(
graph(%a_quant, %b_scalar):
%r = quantized::mul_scalar(%a_quant, %b_scalar)
return (%r) )";
std::string quantized_mul_scalar_out = R"(
graph(%a_quant, %b_scalar):
%r = quantized::mul_scalar_out(%a_quant, %b_scalar, %a_quant)
return (%r) )";
auto mul_scalar_out = getBinaryOpScalarFusionInfo(
"aten::mul_",
{"%b_scalar"},
"quantized::mul_scalar_out",
{"%b_scalar", "%a_quant"},
{input_b_is_scalar});
// quantized::mul_relu
std::string mul_relu = R"(
@@ -689,42 +694,40 @@ graph(%a_quant, %b_quant, %scale, %zero_point, %dtype):
%r = quantized::mul_relu(%a_quant, %b_quant, %scale, %zero_point)
return (%r) )";
// quantized::mul_scalar_relu
std::string mul_scalar_relu = R"(
// quantized::mul_scalar_relu -- fusing quantized::mul_scalar
// and aten::relu
auto quantized_mul_scalar_relu_pattern = R"(
graph(%a_quant, %b_scalar):
%a_dequant = aten::dequantize(%a_quant)
%r_mul = aten::mul(%a_dequant, %b_scalar)
%r_mul = quantized::mul_scalar(%a_quant, %b_scalar)
%r = aten::relu(%r_mul)
return (%r) )";
std::string mul_scalar_inplace_relu = R"(
auto quantized_mul_scalar_inplace_relu_pattern = R"(
graph(%a_quant, %b_scalar):
%a_dequant = aten::dequantize(%a_quant)
%r_mul = aten::mul(%a_dequant, %b_scalar)
%r_mul = quantized::mul_scalar(%a_quant, %b_scalar)
%r = aten::relu_(%r_mul)
return (%r) )";
std::string quantized_mul_scalar_relu = R"(
auto quantized_mul_scalar_relu_replacement = R"(
graph(%a_quant, %b_scalar):
%r = quantized::mul_scalar_relu(%a_quant, %b_scalar)
return (%r) )";
// quantized::mul_scalar_relu_out
std::string inplace_mul_scalar_relu = R"(
// quantized::mul_scalar_relu_out -- fusing quantized::mul_scalar_out
// and aten::relu
auto quantized_mul_scalar_relu_out_pattern = R"(
graph(%a_quant, %b_scalar):
%a_dequant = aten::dequantize(%a_quant)
%r_mul = aten::mul_(%a_dequant, %b_scalar)
%r_mul = quantized::mul_scalar_out(%a_quant, %b_scalar, %a_quant)
%r = aten::relu(%r_mul)
return (%r) )";
std::string inplace_mul_scalar_inplace_relu = R"(
auto quantized_mul_scalar_inplace_relu_out_pattern = R"(
graph(%a_quant, %b_scalar):
%a_dequant = aten::dequantize(%a_quant)
%r_mul = aten::mul_(%a_dequant, %b_scalar)
%r_mul = quantized::mul_scalar_out(%a_quant, %b_scalar, %a_quant)
%r = aten::relu_(%r_mul)
return (%r) )";
std::string quantized_mul_scalar_relu_out = R"(
auto quantized_mul_scalar_relu_out_replacement = R"(
graph(%a_quant, %b_scalar):
%r = quantized::mul_scalar_relu_out(%a_quant, %b_scalar, %a_quant)
return (%r) )";
@@ -905,31 +908,22 @@ graph(%a_quant, %alpha, %scale, %input_scale, %r_scale, %r_zero_point, %r_dtype)
inplace_add_inplace_relu,
quantized_add_relu,
{aten_add_alpha_is_one}},
// note that this must come before quantized::add_scalar
add_scalar,
add_scalar_out,
// note that these must come after quantized::add_scalar and
// quantized::add_scalar_out patterns
{"quantized::add_scalar_relu",
add_scalar_relu,
quantized_add_scalar_relu,
{aten_add_alpha_is_one, input_b_is_scalar}},
quantized_add_scalar_relu_pattern,
quantized_add_scalar_relu_replacement},
{"quantized::add_scalar_relu",
add_scalar_inplace_relu,
quantized_add_scalar_relu,
{aten_add_alpha_is_one, input_b_is_scalar}},
quantized_add_scalar_inplace_relu_pattern,
quantized_add_scalar_relu_replacement},
{"quantized::add_scalar_relu_out",
inplace_add_scalar_relu,
quantized_add_scalar_relu_out,
{aten_add_alpha_is_one, input_b_is_scalar}},
quantized_add_scalar_relu_out_pattern,
quantized_add_scalar_relu_out_replacement},
{"quantized::add_scalar_relu_out",
inplace_add_scalar_inplace_relu,
quantized_add_scalar_relu_out,
{aten_add_alpha_is_one, input_b_is_scalar}},
{"quantized::add_scalar",
add_scalar,
quantized_add_scalar,
{aten_add_alpha_is_one, input_b_is_scalar}},
{"quantized::add_scalar_out",
inplace_add_scalar,
quantized_add_scalar_out,
{aten_add_alpha_is_one, input_b_is_scalar}},
quantized_add_scalar_inplace_relu_out_pattern,
quantized_add_scalar_relu_out_replacement},
{"quantized::add", add, quantized_add, {aten_add_alpha_is_one}},
{"quantized::add", inplace_add, quantized_add, {aten_add_alpha_is_one}},
{"quantized::cat", cat, quantized_cat},
@@ -940,30 +934,22 @@ graph(%a_quant, %alpha, %scale, %input_scale, %r_scale, %r_zero_point, %r_dtype)
{"quantized::batch_norm_relu",
batch_norm_inplace_relu,
quantized_batch_norm_relu},
mul_scalar,
mul_scalar_out,
// note that these must come after quantized::mul_scalar and
// quantized::mul_scalar_out patterns
{"quantized::mul_scalar_relu",
mul_scalar_relu,
quantized_mul_scalar_relu,
{input_b_is_scalar}},
quantized_mul_scalar_relu_pattern,
quantized_mul_scalar_relu_replacement},
{"quantized::mul_scalar_relu",
mul_scalar_inplace_relu,
quantized_mul_scalar_relu,
{input_b_is_scalar}},
quantized_mul_scalar_inplace_relu_pattern,
quantized_mul_scalar_relu_replacement},
{"quantized::mul_scalar_relu_out",
inplace_mul_scalar_relu,
quantized_mul_scalar_relu_out,
{input_b_is_scalar}},
quantized_mul_scalar_relu_out_pattern,
quantized_mul_scalar_relu_out_replacement},
{"quantized::mul_scalar_relu_out",
inplace_mul_scalar_inplace_relu,
quantized_mul_scalar_relu_out,
{input_b_is_scalar}},
{"quantized::mul_scalar",
mul_scalar,
quantized_mul_scalar,
{input_b_is_scalar}},
{"quantized::mul_scalar",
inplace_mul_scalar,
quantized_mul_scalar_out,
{input_b_is_scalar}},
quantized_mul_scalar_inplace_relu_out_pattern,
quantized_mul_scalar_relu_out_replacement},
{"quantized::mul_relu", mul_relu, quantized_mul_relu},
{"quantized::mul_relu", mul_inplace_relu, quantized_mul_relu},
{"quantized::mul_relu", inplace_mul_relu, quantized_mul_relu},

View File

@@ -308,7 +308,7 @@ RegisterOperators reg(
return 0;
},
aliasAnalysisFromSchema()),
// only used internally in range() translation
Operator(
"aten::__range_length(int lo, int hi, int step) -> int",
[](Stack& stack) {