// pytorch/torch/csrc/jit/passes/quantization/helper.cpp

#include <torch/csrc/jit/passes/quantization/helper.h>
#include <torch/csrc/jit/api/function_impl.h>
#include <torch/csrc/jit/passes/graph_rewrite_helper.h>
#include <utility>
namespace torch::jit {
using graph_rewrite_helper::getFuncName;
struct FuncArg {
std::string func_name;
int arg_index;
};
using AtenFuncArgs = std::vector<FuncArg>;
using CallFuncArgs = std::vector<FuncArg>;
// Lists of allowed quantizable operators
std::vector<std::string> _static_quantizable_call_funcs = {
"conv2d",
"linear",
"batch_norm",
"hardswish",
"elu",
"celu",
"layer_norm",
"group_norm",
"instance_norm",
"embedding_bag",
};
std::vector<std::string> _static_quantizable_aten_funcs = {
"conv1d",
"conv2d",
"conv3d",
"conv_transpose1d",
"conv_transpose2d",
"linear",
"hardswish",
"hardswish_",
"elu",
"elu_",
"celu",
"celu_",
"batch_norm",
"layer_norm",
"group_norm",
"instance_norm",
"embedding_bag",
};
std::vector<std::string> _dynamic_quantizable_call_funcs = {
"linear",
};
std::vector<std::string> _dynamic_quantizable_aten_funcs = {
"linear",
};
std::vector<std::string> _static_weight_only_quant_aten_funcs = {
"embedding_bag",
};
std::vector<std::string> _static_weight_only_quant_call_funcs = {
"embedding_bag",
};
// These are the prim::CallFunctions that don't require observation and
// have a single input Tensor,
// e.g. `prim::CallFunction(%dropout, %input_tensor, ...)`,
// so we propagate the observed property from %input_tensor to the
// output of the `prim::CallFunction`.
// Also, these ops don't do computation on the values of the Tensor; the
// operation depends only on the shape of the Tensor.
std::vector<std::string> _single_input_general_shape_call_funcs = {
"_max_pool1d",
"_max_pool2d",
"_max_pool3d",
"dropout",
"relu",
};
// Similar to the prim::CallFunctions above, these are aten ops that don't
// require observation and have a single input Tensor,
// e.g. `aten::flatten(%input_tensor, ...)`.
// Also, these ops don't do computation on the values of the Tensor; the
// operation depends only on the shape of the Tensor.
std::vector<std::string> _single_input_general_shape_aten_funcs = {
"max_pool1d",
"max_pool2d",
"max_pool3d",
"flatten",
"max",
"min",
"dropout",
"reshape",
// Non-inplace resize is deprecated
"resize_",
"chunk",
"view",
"transpose",
"contiguous",
"permute",
"repeat",
"repeat_interleave",
"relu",
"relu_",
"squeeze",
"squeeze_",
"unsqueeze",
"unsqueeze_",
"detach",
"detach_",
"stack",
"__getitem__",
};
// These are prim::CallFunctions for ops that don't require observation and
// have a single input Tensor.
// These ops do computation on the values of the Tensor.
// TODO: [Need verify] looks like we can quantize simple functionals that just
// call into aten functions
std::vector<std::string> _single_input_general_value_call_funcs = {
"avg_pool1d",
"avg_pool2d",
"avg_pool3d",
"adaptive_avg_pool1d",
"adaptive_avg_pool2d",
"adaptive_avg_pool3d",
"interpolate",
"upsample",
"upsample_bilinear",
"upsample_nearest",
"hardtanh",
"leaky_relu",
};
// These are aten functions for ops that don't require observation and
// have a single input Tensor.
// These ops do computation on the values of the Tensor,
// e.g. `aten::avg_pool2d(%input_tensor, ...)`.
std::vector<std::string> _single_input_general_value_aten_funcs = {
"avg_pool1d",
"avg_pool2d",
"avg_pool3d",
"adaptive_avg_pool1d",
"adaptive_avg_pool2d",
"adaptive_avg_pool3d",
"mean",
"upsample_nearest1d",
"upsample_nearest2d",
"upsample_nearest3d",
"upsample_linear1d",
"upsample_bilinear2d",
"upsample_trilinear3d",
"upsample_bicubic2d",
"clamp",
// "clamp_", // Enable when quantized `clamp_` is ready
"hardtanh",
"hardtanh_",
"leaky_relu",
"leaky_relu_",
};
std::vector<std::string> _clamp_funcs = {
"hardtanh",
"hardtanh_",
"clamp",
// "clamp_", // Enable when quantized `clamp_` is ready
};
const float _asym_scale = 1.0f / 256.0f;
const int _asym_zero_point = 0;
const float _sym_scale = 2.0f / 256.0f;
const int _sym_zero_point = 128;
// quantization parameters for ops with range 0 to 1
// for example: aten/src/ATen/native/quantized/cpu/qsigmoid.cpp
std::tuple<c10::QScheme, QParamVector> _per_tensor_asym_qparam =
std::make_tuple(
c10::kPerTensorAffine,
QParamVector(
{std::make_pair(".scale", IValue(_asym_scale)),
std::make_pair(".zero_point", IValue(_asym_zero_point)),
std::make_pair(".scalar_type", IValue(c10::kQUInt8))}));
// quantization parameters for ops with range -1 to 1
// for example: aten/src/ATen/native/quantized/cpu/qtanh.cpp
std::tuple<c10::QScheme, QParamVector> _per_tensor_sym_qparam = std::make_tuple(
c10::kPerTensorAffine,
QParamVector(
{std::make_pair(".scale", IValue(_sym_scale)),
std::make_pair(".zero_point", IValue(_sym_zero_point)),
std::make_pair(".scalar_type", IValue(c10::kQUInt8))}));
// Map from aten op symbol to the quantization parameters
// for the ops with fixed quantization parameters
std::unordered_map<NodeKind, std::tuple<c10::QScheme, QParamVector>>
_fixed_qparams_map = {
{Symbol::aten("hardsigmoid"), _per_tensor_asym_qparam},
{Symbol::aten("hardsigmoid_"), _per_tensor_asym_qparam},
{Symbol::aten("sigmoid"), _per_tensor_asym_qparam},
{Symbol::aten("sigmoid_"), _per_tensor_asym_qparam},
{Symbol::aten("tanh"), _per_tensor_sym_qparam},
{Symbol::aten("tanh_"), _per_tensor_sym_qparam},
};
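// A sketch of how the constants above are derived, assuming quint8 with 256
// quantization levels: an op whose output lies in [0, 1], such as sigmoid,
// gets scale = (1 - 0) / 256 = 1/256 and zero_point = 0; an op whose output
// lies in [-1, 1], such as tanh, gets scale = (1 - (-1)) / 256 = 2/256 and
// zero_point = 128, so that the real value 0 maps exactly to the quantized
// value 128.
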
// Special checks for ops that do not require observers for all input tensors.
// For each operator in this list, an observer is inserted only for the input
// at the specified index.
AtenFuncArgs _observe_inputs_aten_func = {};
CallFuncArgs _observe_inputs_call_func = {{"batch_norm", 1}};
// Aten functions for getting tensor information
std::vector<std::string> _tensor_info_funcs = {"size", "len", "dim", "numel"};
// Aten functions whose output will be quantized or not
// depending on whether the input tensor is quantized
std::vector<std::string> _propagate_quant_single_input_ops = {"cat"};
// Rules are slightly different for binary ops like `aten::add`: for these
// ops, if both inputs are Tensors, we quantize the output only if both
// inputs are quantized; if the second input is a Scalar, we only look at
// the first input to decide whether to quantize the output.
std::vector<std::string> _propagate_quant_binary_ops = {
"add",
"add_",
"mul",
"mul_"};
// Check if `use` is a call to the aten function `func_name` and, if `n` is
// provided, whether the used value is the nth argument of the call.
bool matchAtenFuncToUse(
const Use& use,
const std::string& func_name,
std::optional<int> n) {
Node* node = use.user;
return node->kind() == Symbol::aten(func_name) &&
(!n.has_value() || static_cast<size_t>(n.value()) == use.offset);
}
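// Illustrative example (hypothetical IR, not taken from a real graph): for
//   %r = aten::conv2d(%input, %weight, %bias, %stride, %padding, %dilation, %groups)
// the Use of %weight has user == the conv2d node and offset == 1, so
// matchAtenFuncToUse(use, "conv2d", 1) returns true while
// matchAtenFuncToUse(use, "conv2d", 0) returns false.
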
bool matchCallFuncToUse(
const Use& use,
const std::string& func_name,
std::optional<int> n) {
Node* node = use.user;
return node->kind() == prim::CallFunction &&
getFuncName(node->inputs()[0]) == func_name &&
(!n.has_value() || static_cast<size_t>(n.value()) == use.offset);
}
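// Note the offset convention for prim::CallFunction: input 0 is the function
// value itself, so argument indices are shifted by one relative to the aten
// overload. For a hypothetical
//   %r = prim::CallFunction(%linear_fn, %input, %weight, %bias)
// the weight sits at offset 2, which is why the CallFuncArgs in isWeight
// below use {"linear", 2} while the AtenFuncArgs use {"linear", 1}.
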
// Check if any use of `v` matches the aten function call
// or CallFunction patterns
static bool matchArgPattern(
Value* v,
const AtenFuncArgs& aten_func_args,
const CallFuncArgs& call_func_args) {
for (const Use& u : v->uses()) {
for (const auto& func_arg : aten_func_args) {
if (matchAtenFuncToUse(u, func_arg.func_name, func_arg.arg_index)) {
return true;
}
}
for (const auto& func_arg : call_func_args) {
if (matchCallFuncToUse(u, func_arg.func_name, func_arg.arg_index)) {
return true;
}
}
}
return false;
}
// TODO add other op signatures.
bool isWeight(Value* v) {
bool result = matchArgPattern(
v,
      // aten::embedding_bag(%weight, %input, %offsets, %scale_grad_by_freq,
      // %mode_enum, %sparse, %per_sample_weights, %include_last_offset)
AtenFuncArgs(
{{"conv1d", 1},
{"conv2d", 1},
{"conv3d", 1},
{"conv_transpose1d", 1},
{"conv_transpose2d", 1},
{"linear", 1},
{"embedding_bag", 0}}),
// embedding_bag - prim::CallFunction(%func, %input.1, %weight,
// %offsets.1, %max_norm, %norm_type, %scale_grad_by_freq, %mode, %sparse,
// %per_sample_weights.1, %include_last_offset)
CallFuncArgs({{"linear", 2}, {"embedding_bag", 2}}));
return result;
}
bool isBiasOfConvOrLinear(Value* v) {
bool result = matchArgPattern(
v,
AtenFuncArgs(
{{"conv1d", 2},
{"conv2d", 2},
{"conv3d", 2},
{"conv_transpose1d", 2},
{"conv_transpose2d", 2},
{"linear", 2}}),
CallFuncArgs({{"linear", 3}}));
return result;
}
bool isEmbeddingBagNonInput(Value* v) {
bool result = matchArgPattern(
v,
AtenFuncArgs({{"embedding_bag", 2}, {"embedding_bag", 6}}),
CallFuncArgs({}));
return result;
}
std::optional<Use> getClampScalarInputUse(Value* v) {
for (const auto& use : v->uses()) {
for (const auto& aten_func : _clamp_funcs) {
if (matchAtenFuncToUse(use, aten_func, 1) ||
matchAtenFuncToUse(use, aten_func, 2)) {
return use;
}
}
}
return std::nullopt;
}
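// Example (hypothetical IR): for `%r = aten::hardtanh(%x, %min_val, %max_val)`,
// the uses of %min_val (offset 1) and %max_val (offset 2) both match, so
// getClampScalarInputUse returns the use of whichever scalar bound `v` feeds.
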
void cloneMethod(
Module& module,
const std::string& orig_method_name,
const std::string& new_method_name) {
const Function& method = module.get_method(orig_method_name).function();
auto graph = toGraphFunction(method).graph()->copy();
const auto& schema = method.getSchema();
const auto this_method_name =
c10::QualifiedName(*module.type()->name(), new_method_name);
auto copied = module._ivalue()->compilation_unit()->create_function(
this_method_name, std::move(graph));
module.type()->addMethod(copied);
copied->setSchema(schema);
}
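// Usage sketch (the method names here are hypothetical): a pass can call
//   cloneMethod(module, "forward", "forward_copy");
// to get an independent copy of the graph under a new method name that it
// can rewrite without mutating the original "forward".
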
std::vector<Value*> getPassThroughInputs(Value* v) {
Node* n = v->node();
if (isSingleInputGeneralCallFunction(n)) {
return {n->input(1)};
} else if (
isSingleInputGeneralAtenFunction(n) ||
(n->kind() == Symbol::aten("sort") && v->offset() == 0)) {
return {n->input(0)};
} else if (n->kind() == prim::If && n->outputs().size() == 1) {
std::vector<Value*> inputs;
for (Block* subblock : n->blocks()) {
if (alwaysRaisesException(subblock)) {
continue;
}
auto* output = subblock->outputs()[0];
inputs.push_back(output);
}
return inputs;
} else if (n->kind() == prim::ListUnpack || n->kind() == prim::TupleUnpack) {
// only propagate dequantize for Tensor
if (v->type()->isSubtypeOf(*TensorType::get())) {
return {n->input(0)};
} else {
return {};
}
} else if (
n->kind() == prim::ListConstruct &&
v->type()->isSubtypeOf(*ListType::ofTensors())) {
    std::vector<Value*> inputs;
    for (auto* input : n->inputs()) {
      inputs.push_back(input);
    }
    return inputs;
} else if (n->kind() == prim::TupleConstruct) {
std::vector<Value*> inputs;
for (auto* input : n->inputs()) {
if (input->type()->isSubtypeOf(*TensorType::get())) {
inputs.push_back(input);
}
}
return inputs;
} else if (n->kind() == Symbol::aten("append")) {
std::vector<Value*> inputs;
for (auto* input : n->inputs()) {
inputs.push_back(input);
}
return inputs;
}
return {};
}
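// Example (hypothetical IR): for `%y = aten::flatten(%x, %start, %end)`,
// getPassThroughInputs(%y) returns {%x}; flatten is a single-input general
// shape op, so the observed/quantized property of %x can be propagated to %y
// without inserting a new observer.
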
static std::vector<NodeKind> toAtenSymbol(
const std::vector<std::string>& func_names) {
std::vector<NodeKind> symbols;
std::transform(
func_names.begin(),
func_names.end(),
std::back_inserter(symbols),
Symbol::aten);
return symbols;
}
static bool isAtenFunc(Node* n, const std::vector<NodeKind>& aten_funcs) {
return std::find(aten_funcs.begin(), aten_funcs.end(), n->kind()) !=
aten_funcs.end();
}
static bool isAtenFunc(Node* n, const std::vector<std::string>& aten_funcs) {
const auto& symbols = toAtenSymbol(aten_funcs);
return isAtenFunc(n, symbols);
}
// TODO: factor out isCallFunc
static bool isFunctionNode(
Node* n,
const std::vector<std::string>& call_funcs,
const std::vector<std::string>& aten_funcs) {
bool is_func_node = isAtenFunc(n, aten_funcs);
if (n->kind() == prim::CallFunction) {
auto func_name = getFuncName(n->inputs()[0]);
is_func_node |=
std::find(call_funcs.begin(), call_funcs.end(), func_name) !=
call_funcs.end();
}
return is_func_node;
}
bool isSingleInputGeneralShapeAtenFunction(Node* n) {
return isAtenFunc(n, _single_input_general_shape_aten_funcs);
}
bool isSingleInputGeneralValueAtenFunction(Node* n) {
return isAtenFunc(n, _single_input_general_value_aten_funcs) ||
isBinaryOpWithScalarInput(n);
}
bool isSingleInputGeneralCallFunction(Node* n) {
  // Concatenation of the shape and value call-func lists, built once so
  // that repeated calls don't keep appending to the static vector.
  static const std::vector<std::string> single_input_general_call_funcs = [] {
    std::vector<std::string> funcs;
    funcs.insert(
        funcs.end(),
        _single_input_general_shape_call_funcs.begin(),
        _single_input_general_shape_call_funcs.end());
    funcs.insert(
        funcs.end(),
        _single_input_general_value_call_funcs.begin(),
        _single_input_general_value_call_funcs.end());
    return funcs;
  }();
  return isFunctionNode(
      n,
      /* call_funcs = */ single_input_general_call_funcs,
      /* aten_funcs = */ {});
}
bool isSingleInputGeneralAtenFunction(Node* n) {
  // Keys of _fixed_qparams_map, built once so that repeated calls don't
  // keep appending to the static vector.
  static const std::vector<NodeKind> fixed_qparams_aten_funcs = [] {
    std::vector<NodeKind> funcs;
    std::transform(
        _fixed_qparams_map.begin(),
        _fixed_qparams_map.end(),
        std::back_inserter(funcs),
        [](const auto& pair) { return pair.first; });
    return funcs;
  }();
  return isSingleInputGeneralValueAtenFunction(n) ||
      isSingleInputGeneralShapeAtenFunction(n) ||
      isAtenFunc(n, fixed_qparams_aten_funcs);
}
bool isClamp(Node* n) {
return isAtenFunc(n, _clamp_funcs);
}
bool isTensorInfoNode(Node* n) {
return isAtenFunc(n, _tensor_info_funcs);
}
bool isPropagateQuantSingleInputOp(Node* n) {
return isAtenFunc(n, _propagate_quant_single_input_ops);
}
bool isPropagateQuantBinaryOp(Node* n) {
return isAtenFunc(n, _propagate_quant_binary_ops);
}
bool isPropagateQuantOp(Node* n) {
return isPropagateQuantSingleInputOp(n) || isPropagateQuantBinaryOp(n);
}
bool isBinaryOpWithScalarInput(Node* n) {
return isPropagateQuantBinaryOp(n) && isScalar(n->input(1));
}
std::optional<std::tuple<c10::QScheme, QParamVector>> getFixedQParams(Node* n) {
  // Keys of _fixed_qparams_map, built once so that repeated calls don't
  // keep appending to the static vector.
  static const std::vector<NodeKind> fixed_qparam_funcs = [] {
    std::vector<NodeKind> funcs;
    std::transform(
        _fixed_qparams_map.begin(),
        _fixed_qparams_map.end(),
        std::back_inserter(funcs),
        [](const auto& pair) { return pair.first; });
    return funcs;
  }();
  if (isAtenFunc(n, fixed_qparam_funcs)) {
    return _fixed_qparams_map.at(n->kind());
  }
  return std::nullopt;
}
bool userDefinedCallFunction(Node* n) {
return n->kind() == prim::CallFunction &&
!isSingleInputGeneralCallFunction(n) &&
!isFunctionNode(n, _static_quantizable_call_funcs, {});
}
bool isWeightOnlyStaticQuantOp(Node* n) {
return isFunctionNode(
n,
_static_weight_only_quant_call_funcs,
_static_weight_only_quant_aten_funcs);
}
bool nodeQuantizable(Node* n, QuantType quant_type) {
bool is_dynamic = quant_type == QuantType::DYNAMIC;
return isFunctionNode(
n,
/* call_funcs = */
is_dynamic ? _dynamic_quantizable_call_funcs
: _static_quantizable_call_funcs,
/* aten_funcs = */
is_dynamic ? _dynamic_quantizable_aten_funcs
: _static_quantizable_aten_funcs);
}
bool useQuantizable(const Use& use, QuantType quant_type) {
if (quant_type == QuantType::STATIC) {
for (const auto& func_input : _observe_inputs_aten_func) {
if (matchAtenFuncToUse(use, func_input.func_name, std::nullopt)) {
return use.offset == static_cast<size_t>(func_input.arg_index);
}
}
for (const auto& func_input : _observe_inputs_call_func) {
if (matchCallFuncToUse(use, func_input.func_name, std::nullopt)) {
return use.offset == static_cast<size_t>(func_input.arg_index);
}
}
}
return nodeQuantizable(use.user, quant_type);
}
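// Example: _observe_inputs_call_func above contains {"batch_norm", 1}, so for
// a hypothetical `prim::CallFunction(%batch_norm_fn, %input, %weight, ...)`
// under static quantization only the use at offset 1 (%input) is considered
// quantizable; the uses of the remaining arguments are not observed.
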
std::shared_ptr<Graph> getCallFunctionGraph(Node* n) {
auto* func_node = n->input(0)->node();
auto func = func_node->output()->type()->expectRef<FunctionType>().function();
auto graphFunc = tryToGraphFunction(*func);
TORCH_CHECK(graphFunc, "Quantization only works for graph function");
return graphFunc->graph();
}
// Block helper functions
bool alwaysRaisesException(Block* block) {
for (Node* n : block->nodes()) {
if (n->kind() == prim::RaiseException) {
return true;
}
if (n->kind() == prim::If) {
bool exception = true;
for (Block* b : n->blocks()) {
exception &= alwaysRaisesException(b);
}
if (exception) {
return true;
}
}
}
return false;
}
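// Example (hypothetical IR sketch): a guard block such as
//   block0():
//     %msg : str = prim::Constant[value="Expected quantized input"]()
//     prim::RaiseException(%msg)
//     -> (%dummy)
// always raises, so getPassThroughInputs skips it when collecting the
// outputs of a prim::If node.
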
// Check if a value in the graph is a Scalar value
bool isScalar(Value* v) {
auto iv = toIValue(v);
return v->type()->isSubtypeOf(*NumberType::get()) ||
(v->type()->isSubtypeOf(*TensorType::get()) && iv && iv->isTensor() &&
iv->toTensor().dim() == 0);
}
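// Example: an int constant such as `%n : int = prim::Constant[value=1]()`
// is a scalar by its type alone; a constant zero-dimensional tensor also
// counts, which is what the toIValue lookup and the dim() == 0 check detect.
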
// =================== Graph/Module analysis helper functions ============
// Check if value is the input of the graph
bool hitGraphInput(Value* value) {
Graph* graph = value->owningGraph();
const auto& inputs = graph->inputs();
return std::find(inputs.begin(), inputs.end(), value) != inputs.end();
}
// Get the module access path for a Value representing a module instance
// by tracing back the GetAttr nodes and recording all the attribute
// names along the way.
// Assuming 'self.sub.basic_block.conv1',
// Input1: Value instance of conv1
// Input2: Value instance of self
// Output: ['sub', 'basic_block', 'conv1']
std::vector<std::string> getModuleAccessPath(Value* instance, Value* self) {
std::vector<std::string> path;
// Iterator to traverse back the GetAttr calls
Value* iter = instance;
// trace back the instance to recover the path of the submodule
while (!hitGraphInput(iter) && iter->node()->kind() == prim::GetAttr) {
Node* get_attr = iter->node();
// record the name of GetAttr
path.push_back(get_attr->s(attr::name));
// trace back the chain of GetAttr
iter = get_attr->inputs()[0];
}
  TORCH_CHECK(
      iter == self,
      "Can't handle the access pattern of GetAttr "
      "in getModuleAccessPath, traced back to: ",
      iter->debugName(),
      " which is not self: ",
      self->debugName());
std::reverse(path.begin(), path.end());
return path;
}
// Assuming self.foo.bar.conv1,
// Input1: Module instance of self
// Input2: ['foo', 'bar', 'conv1']
// Output: Module instance of conv1
Module findChildModule(
const Module& module,
const std::vector<std::string>& path) {
Module m = module;
for (const auto& p : path) {
m = m.attr(p).toModule();
}
return m;
}
Module getInvokedModule(Module& module, Node* n, Value* self) {
auto* instance = n->inputs()[0];
auto path = getModuleAccessPath(instance, self);
return findChildModule(module, path);
}
std::optional<Module> getInvokedModuleOpt(
const Module& module,
Node* n,
Value* self) {
auto* instance = n->inputs()[0];
auto path = getModuleAccessPath(instance, self);
Module m = module;
for (const auto& p : path) {
if (m.attr(p).isModule()) {
m = m.attr(p).toModule();
} else {
return std::nullopt;
}
}
return m;
}
// ==================== filter functions for matches ==============
bool is_int_constant(
const Match& match,
const std::unordered_map<std::string, Value*>& vmap,
const std::string& vname,
int value) {
const auto& match_vmap = match.values_map;
auto v = toIValue(match_vmap.at(vmap.at(vname)));
return v && v->isInt() && v->toInt() == value;
}
static bool is_functional(
const Match& match,
const std::unordered_map<std::string, Value*>& vmap,
const std::string& vname,
const std::string& functional) {
const auto& match_vmap = match.values_map;
Value* v = match_vmap.at(vmap.at(vname));
return v->type()->cast<FunctionType>() && getFuncName(v) == functional;
}
std::string removeTorchMangle(const std::string& orig_name) {
static std::regex mangle_re("\\.___torch_mangle_\\d+");
auto qualified_name = std::regex_replace(orig_name, mangle_re, "");
return qualified_name;
}
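// Example: removeTorchMangle("__torch__.MyModule.___torch_mangle_3") returns
// "__torch__.MyModule" (MyModule is a hypothetical class name; the
// ___torch_mangle_<N> suffix is appended by the compiler to disambiguate
// repeated compilations of a class).
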
std::optional<std::string> getModuleName(Value* value) {
auto type = value->type()->cast<ClassType>();
if (type && type->name()) {
return removeTorchMangle(type->name()->qualifiedName());
}
return std::nullopt;
}
static bool is_module(
const Match& match,
const std::unordered_map<std::string, Value*>& vmap,
const std::string& vname,
const std::string& module_qualified_name) {
const auto& match_vmap = match.values_map;
Value* v = match_vmap.at(vmap.at(vname));
auto module_name = getModuleName(v);
if (module_name.has_value()) {
return module_name.value() == module_qualified_name;
}
return false;
}
bool aten_add_alpha_is_one(
const Match& match,
const std::unordered_map<std::string, Value*>& vmap) {
return is_int_constant(match, vmap, "alpha", 1);
}
bool is_functional_relu(
const Match& match,
const std::unordered_map<std::string, Value*>& vmap) {
return is_functional(match, vmap, "relu", "relu");
}
bool is_relu_module(
const Match& match,
const std::unordered_map<std::string, Value*>& vmap) {
return is_module(
match, vmap, "relu", "__torch__.torch.nn.modules.activation.ReLU");
}
bool is_linear_module(
const Match& match,
const std::unordered_map<std::string, Value*>& vmap) {
return is_module(
match, vmap, "linear", "__torch__.torch.nn.modules.linear.Linear");
}
bool is_conv1d_module(
const Match& match,
const std::unordered_map<std::string, Value*>& vmap) {
return is_module(
match, vmap, "conv", "__torch__.torch.nn.modules.conv.Conv1d");
}
bool is_conv2d_module(
const Match& match,
const std::unordered_map<std::string, Value*>& vmap) {
return is_module(
match, vmap, "conv", "__torch__.torch.nn.modules.conv.Conv2d");
}
bool is_conv3d_module(
const Match& match,
const std::unordered_map<std::string, Value*>& vmap) {
return is_module(
match, vmap, "conv", "__torch__.torch.nn.modules.conv.Conv3d");
}
bool is_conv_transpose1d_module(
const Match& match,
const std::unordered_map<std::string, Value*>& vmap) {
return is_module(
match, vmap, "conv", "__torch__.torch.nn.modules.conv.ConvTranspose1d");
}
bool is_conv_transpose2d_module(
const Match& match,
const std::unordered_map<std::string, Value*>& vmap) {
return is_module(
match, vmap, "conv", "__torch__.torch.nn.modules.conv.ConvTranspose2d");
}
bool is_batchnorm2d_module(
const Match& match,
const std::unordered_map<std::string, Value*>& vmap) {
bool regnorm = is_module(
match,
vmap,
"batchnorm",
"__torch__.torch.nn.modules.batchnorm.BatchNorm2d");
bool naivenorm = is_module(
match,
vmap,
"batchnorm",
"__torch__.mobile_cv.arch.layers.batch_norm.NaiveSyncBatchNorm");
return (regnorm || naivenorm);
}
bool is_batchnorm3d_module(
const Match& match,
const std::unordered_map<std::string, Value*>& vmap) {
return is_module(
match,
vmap,
"batchnorm",
"__torch__.torch.nn.modules.batchnorm.BatchNorm3d");
}
} // namespace torch::jit